diff --git a/LlamaFactory/.github/ISSUE_TEMPLATE/1-bug-report.yml b/LlamaFactory/.github/ISSUE_TEMPLATE/1-bug-report.yml new file mode 100644 index 0000000000000000000000000000000000000000..a08596faa5b3be2545412d372f7bdeadca95afb4 --- /dev/null +++ b/LlamaFactory/.github/ISSUE_TEMPLATE/1-bug-report.yml @@ -0,0 +1,61 @@ +name: "\U0001F41B Bug / help" +description: Create a report to help us improve the LLaMA Factory +labels: ["bug", "pending"] +body: + - type: markdown + attributes: + value: | + Issues included in **[FAQs](https://github.com/hiyouga/LLaMA-Factory/issues/4614)** or those with **insufficient** information may be closed without a response. + 已经包含在 **[常见问题](https://github.com/hiyouga/LLaMA-Factory/issues/4614)** 内或提供信息**不完整**的 issues 可能不会被回复。 + + - type: markdown + attributes: + value: | + Please do not create issues that are not related to framework bugs under this category, use **[Discussions](https://github.com/hiyouga/LLaMA-Factory/discussions/categories/q-a)** instead. + 请勿在此分类下创建和框架 bug 无关的 issues,训练问题求助请使用 **[讨论区](https://github.com/hiyouga/LLaMA-Factory/discussions/categories/q-a)**。 + + - type: checkboxes + id: reminder + attributes: + label: Reminder + description: | + Please ensure you have read the above rules carefully and searched the existing issues (including FAQs). + 请确保您已经认真阅读了上述规则并且搜索过现有的 issues(包括常见问题)。 + + options: + - label: I have read the above rules and searched the existing issues. + required: true + + - type: textarea + id: system-info + validations: + required: true + attributes: + label: System Info + description: | + Please share your system info with us. You can run the command **llamafactory-cli env** and copy-paste its output below. + 请提供您的系统信息。您可以在命令行运行 **llamafactory-cli env** 并将其输出复制到该文本框中。 + + placeholder: llamafactory version, platform, python version, ... + + - type: textarea + id: reproduction + validations: + required: true + attributes: + label: Reproduction + description: | + Please provide entry arguments, error messages and stack traces that reproduces the problem. + 请提供入口参数,错误日志以及异常堆栈以便于我们复现问题。 + + value: | + ```text + Put your message here. + ``` + + - type: textarea + id: others + validations: + required: false + attributes: + label: Others diff --git a/LlamaFactory/.github/ISSUE_TEMPLATE/2-feature-request.yml b/LlamaFactory/.github/ISSUE_TEMPLATE/2-feature-request.yml new file mode 100644 index 0000000000000000000000000000000000000000..5d72271ebc8db3d10bf7e9c6af209e857566bde6 --- /dev/null +++ b/LlamaFactory/.github/ISSUE_TEMPLATE/2-feature-request.yml @@ -0,0 +1,41 @@ +name: "\U0001F680 Feature request" +description: Submit a request for a new feature +labels: ["enhancement", "pending"] +body: + - type: markdown + attributes: + value: | + Please do not create issues that are not related to new features under this category. + 请勿在此分类下创建和新特性无关的 issues。 + + - type: checkboxes + id: reminder + attributes: + label: Reminder + description: | + Please ensure you have read the above rules carefully and searched the existing issues. + 请确保您已经认真阅读了上述规则并且搜索过现有的 issues。 + + options: + - label: I have read the above rules and searched the existing issues. + required: true + + - type: textarea + id: description + validations: + required: true + attributes: + label: Description + description: | + A clear and concise description of the feature proposal. + 请详细描述您希望加入的新功能特性。 + + - type: textarea + id: contribution + validations: + required: false + attributes: + label: Pull Request + description: | + Have you already created the relevant PR and submitted the code? + 您是否已经创建了相关 PR 并提交了代码? diff --git a/LlamaFactory/.github/ISSUE_TEMPLATE/config.yml b/LlamaFactory/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000000000000000000000000000000000000..1a7719634963d9d78bfa5155b51c5a82311084e4 --- /dev/null +++ b/LlamaFactory/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,8 @@ +blank_issues_enabled: false +contact_links: + - name: 📚 FAQs | 常见问题 + url: https://github.com/hiyouga/LLaMA-Factory/issues/4614 + about: Reading in advance is recommended | 建议提前阅读 + - name: Discussions | 讨论区 + url: https://github.com/hiyouga/LLaMA-Factory/discussions + about: Please ask fine-tuning questions here | 请在这里讨论训练问题 diff --git a/LlamaFactory/.github/workflows/docker.yml b/LlamaFactory/.github/workflows/docker.yml new file mode 100644 index 0000000000000000000000000000000000000000..fea0a92776530571c7733e70c76216a09aeb4d12 --- /dev/null +++ b/LlamaFactory/.github/workflows/docker.yml @@ -0,0 +1,116 @@ +name: docker + +on: + workflow_dispatch: + push: + branches: + - "main" + paths: + - "**/*.py" + - "pyproject.toml" + - "docker/**" + - ".github/workflows/*.yml" + pull_request: + branches: + - "main" + paths: + - "**/*.py" + - "pyproject.toml" + - "docker/**" + - ".github/workflows/*.yml" + release: + types: + - published + +jobs: + build: + strategy: + fail-fast: false + matrix: + include: + - device: "cuda" + - device: "npu-a2" + - device: "npu-a3" + + runs-on: ubuntu-latest + + concurrency: + group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.device }} + cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} + + environment: + name: docker + url: https://hub.docker.com/r/hiyouga/llamafactory + + steps: + - name: Free up disk space + uses: jlumbroso/free-disk-space@v1.3.1 + with: + tool-cache: true + docker-images: false + + - name: Checkout + uses: actions/checkout@v6 + + - name: Get llamafactory version + id: version + run: | + if [ "${{ github.event_name }}" = "release" ]; then + echo "tag=$(grep -oP 'VERSION = "\K[^"]+' src/llamafactory/extras/env.py)" >> "$GITHUB_OUTPUT" + else + echo "tag=latest" >> "$GITHUB_OUTPUT" + fi + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Login to Docker Hub + if: ${{ github.event_name != 'pull_request' }} + uses: docker/login-action@v3 + with: + username: ${{ vars.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - name: Login to Quay + if: ${{ github.event_name != 'pull_request' && startsWith(matrix.device, 'npu') }} + uses: docker/login-action@v3 + with: + registry: quay.io + username: ${{ vars.QUAY_ASCEND_USERNAME }} + password: ${{ secrets.QUAY_ASCEND_TOKEN }} + + - name: Build and push Docker image (CUDA) + if: ${{ matrix.device == 'cuda' }} + uses: docker/build-push-action@v6 + with: + context: . + file: ./docker/docker-cuda/Dockerfile + push: ${{ github.event_name != 'pull_request' }} + tags: | + docker.io/hiyouga/llamafactory:${{ steps.version.outputs.tag }} + + - name: Build and push Docker image (NPU-A2) + if: ${{ matrix.device == 'npu-a2' }} + uses: docker/build-push-action@v6 + with: + context: . + platforms: linux/amd64,linux/arm64 + file: ./docker/docker-npu/Dockerfile + push: ${{ github.event_name != 'pull_request' }} + tags: | + docker.io/hiyouga/llamafactory:${{ steps.version.outputs.tag }}-npu-a2 + quay.io/ascend/llamafactory:${{ steps.version.outputs.tag }}-npu-a2 + + - name: Build and push Docker image (NPU-A3) + if: ${{ matrix.device == 'npu-a3' }} + uses: docker/build-push-action@v6 + with: + context: . + platforms: linux/amd64,linux/arm64 + file: ./docker/docker-npu/Dockerfile + build-args: | + BASE_IMAGE=quay.io/ascend/cann:8.3.rc2-a3-ubuntu22.04-py3.11 + push: ${{ github.event_name != 'pull_request' }} + tags: | + docker.io/hiyouga/llamafactory:${{ steps.version.outputs.tag }}-npu-a3 + quay.io/ascend/llamafactory:${{ steps.version.outputs.tag }}-npu-a3 diff --git a/LlamaFactory/.github/workflows/publish.yml b/LlamaFactory/.github/workflows/publish.yml new file mode 100644 index 0000000000000000000000000000000000000000..41cbff65544e4922cfe6a770005467a005d59aa1 --- /dev/null +++ b/LlamaFactory/.github/workflows/publish.yml @@ -0,0 +1,37 @@ +name: publish + +on: + workflow_dispatch: + release: + types: + - published + +jobs: + publish: + name: Upload release to PyPI + + runs-on: ubuntu-latest + + environment: + name: release + url: https://pypi.org/p/llamafactory + + permissions: + id-token: write + + steps: + - name: Checkout + uses: actions/checkout@v6 + + - name: Install uv + uses: astral-sh/setup-uv@v7 + with: + python-version: "3.11" + github-token: ${{ github.token }} + + - name: Build package + run: | + make build + + - name: Publish package + uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/LlamaFactory/src/api.py b/LlamaFactory/src/api.py new file mode 100644 index 0000000000000000000000000000000000000000..61215459ed91c6fa529a719cb9dac57223754d2e --- /dev/null +++ b/LlamaFactory/src/api.py @@ -0,0 +1,33 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import uvicorn + +from llamafactory.api.app import create_app +from llamafactory.chat import ChatModel + + +def main(): + chat_model = ChatModel() + app = create_app(chat_model) + api_host = os.getenv("API_HOST", "0.0.0.0") + api_port = int(os.getenv("API_PORT", "8000")) + print(f"Visit http://localhost:{api_port}/docs for API document.") + uvicorn.run(app, host=api_host, port=api_port) + + +if __name__ == "__main__": + main() diff --git a/LlamaFactory/src/llamafactory/__init__.py b/LlamaFactory/src/llamafactory/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b1567ef572714881cc464db25d3da3d08a460963 --- /dev/null +++ b/LlamaFactory/src/llamafactory/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Efficient fine-tuning of large language models. + +Level: + api, webui > chat, eval, train > data, model > hparams > extras + +Disable version checking: DISABLE_VERSION_CHECK=1 +Enable VRAM recording: RECORD_VRAM=1 +Force using torchrun: FORCE_TORCHRUN=1 +Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN +Use modelscope: USE_MODELSCOPE_HUB=1 +Use openmind: USE_OPENMIND_HUB=1 +""" + +from .extras.env import VERSION + + +__version__ = VERSION diff --git a/LlamaFactory/src/llamafactory/__pycache__/__init__.cpython-311.pyc b/LlamaFactory/src/llamafactory/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5baa0e919b345df025d2d0fefcaa8bc97dc4b6be Binary files /dev/null and b/LlamaFactory/src/llamafactory/__pycache__/__init__.cpython-311.pyc differ diff --git a/LlamaFactory/src/llamafactory/__pycache__/__init__.cpython-312.pyc b/LlamaFactory/src/llamafactory/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a37e0f7ca8fee40ccbdaf994d7137193c67bf576 Binary files /dev/null and b/LlamaFactory/src/llamafactory/__pycache__/__init__.cpython-312.pyc differ diff --git a/LlamaFactory/src/llamafactory/__pycache__/cli.cpython-311.pyc b/LlamaFactory/src/llamafactory/__pycache__/cli.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a1b2794395848764e3378aea7114e3f8a979c68 Binary files /dev/null and b/LlamaFactory/src/llamafactory/__pycache__/cli.cpython-311.pyc differ diff --git a/LlamaFactory/src/llamafactory/__pycache__/cli.cpython-312.pyc b/LlamaFactory/src/llamafactory/__pycache__/cli.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47f0ae90458bd19e27cede1f0f484bff5b744c67 Binary files /dev/null and b/LlamaFactory/src/llamafactory/__pycache__/cli.cpython-312.pyc differ diff --git a/LlamaFactory/src/llamafactory/__pycache__/launcher.cpython-311.pyc b/LlamaFactory/src/llamafactory/__pycache__/launcher.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ac40e601e824808047c861384e2b0ebea9b7ecb Binary files /dev/null and b/LlamaFactory/src/llamafactory/__pycache__/launcher.cpython-311.pyc differ diff --git a/LlamaFactory/src/llamafactory/__pycache__/launcher.cpython-312.pyc b/LlamaFactory/src/llamafactory/__pycache__/launcher.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43ed331273537212e37d6ae392f2fcab6b86d334 Binary files /dev/null and b/LlamaFactory/src/llamafactory/__pycache__/launcher.cpython-312.pyc differ diff --git a/LlamaFactory/src/llamafactory/api/__init__.py b/LlamaFactory/src/llamafactory/api/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/LlamaFactory/src/llamafactory/api/__pycache__/common.cpython-311.pyc b/LlamaFactory/src/llamafactory/api/__pycache__/common.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0767dfc39c98ccf34d788e60d27b79aee5caf98 Binary files /dev/null and b/LlamaFactory/src/llamafactory/api/__pycache__/common.cpython-311.pyc differ diff --git a/LlamaFactory/src/llamafactory/api/__pycache__/protocol.cpython-311.pyc b/LlamaFactory/src/llamafactory/api/__pycache__/protocol.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7e3d025e8de5e0cdaa964dc4e72782385267def Binary files /dev/null and b/LlamaFactory/src/llamafactory/api/__pycache__/protocol.cpython-311.pyc differ diff --git a/LlamaFactory/src/llamafactory/api/app.py b/LlamaFactory/src/llamafactory/api/app.py new file mode 100644 index 0000000000000000000000000000000000000000..8ec0679cb7e053058f52bdbf947cb13e554c5ca8 --- /dev/null +++ b/LlamaFactory/src/llamafactory/api/app.py @@ -0,0 +1,133 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import os +from contextlib import asynccontextmanager +from functools import partial +from typing import Annotated + +from ..chat import ChatModel +from ..extras.constants import EngineName +from ..extras.misc import torch_gc +from ..extras.packages import is_fastapi_available, is_starlette_available, is_uvicorn_available +from .chat import ( + create_chat_completion_response, + create_score_evaluation_response, + create_stream_chat_completion_response, +) +from .protocol import ( + ChatCompletionRequest, + ChatCompletionResponse, + ModelCard, + ModelList, + ScoreEvaluationRequest, + ScoreEvaluationResponse, +) + + +if is_fastapi_available(): + from fastapi import Depends, FastAPI, HTTPException, status + from fastapi.middleware.cors import CORSMiddleware + from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer + + +if is_starlette_available(): + from sse_starlette import EventSourceResponse + + +if is_uvicorn_available(): + import uvicorn + + +async def sweeper() -> None: + while True: + torch_gc() + await asyncio.sleep(300) + + +@asynccontextmanager +async def lifespan(app: "FastAPI", chat_model: "ChatModel"): # collects GPU memory + if chat_model.engine.name == EngineName.HF: + asyncio.create_task(sweeper()) + + yield + torch_gc() + + +def create_app(chat_model: "ChatModel") -> "FastAPI": + root_path = os.getenv("FASTAPI_ROOT_PATH", "") + app = FastAPI(lifespan=partial(lifespan, chat_model=chat_model), root_path=root_path) + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + api_key = os.getenv("API_KEY") + security = HTTPBearer(auto_error=False) + + async def verify_api_key(auth: Annotated[HTTPAuthorizationCredentials | None, Depends(security)]): + if api_key and (auth is None or auth.credentials != api_key): + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.") + + @app.get( + "/v1/models", + response_model=ModelList, + status_code=status.HTTP_200_OK, + dependencies=[Depends(verify_api_key)], + ) + async def list_models(): + model_card = ModelCard(id=os.getenv("API_MODEL_NAME", "gpt-3.5-turbo")) + return ModelList(data=[model_card]) + + @app.post( + "/v1/chat/completions", + response_model=ChatCompletionResponse, + status_code=status.HTTP_200_OK, + dependencies=[Depends(verify_api_key)], + ) + async def create_chat_completion(request: ChatCompletionRequest): + if not chat_model.engine.can_generate: + raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed") + + if request.stream: + generate = create_stream_chat_completion_response(request, chat_model) + return EventSourceResponse(generate, media_type="text/event-stream", sep="\n") + else: + return await create_chat_completion_response(request, chat_model) + + @app.post( + "/v1/score/evaluation", + response_model=ScoreEvaluationResponse, + status_code=status.HTTP_200_OK, + dependencies=[Depends(verify_api_key)], + ) + async def create_score_evaluation(request: ScoreEvaluationRequest): + if chat_model.engine.can_generate: + raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed") + + return await create_score_evaluation_response(request, chat_model) + + return app + + +def run_api() -> None: + chat_model = ChatModel() + app = create_app(chat_model) + api_host = os.getenv("API_HOST", "0.0.0.0") + api_port = int(os.getenv("API_PORT", "8000")) + print(f"Visit http://localhost:{api_port}/docs for API document.") + uvicorn.run(app, host=api_host, port=api_port) diff --git a/LlamaFactory/src/llamafactory/api/chat.py b/LlamaFactory/src/llamafactory/api/chat.py new file mode 100644 index 0000000000000000000000000000000000000000..93236c5ca865492f0c45e1f5ab56a389875350ea --- /dev/null +++ b/LlamaFactory/src/llamafactory/api/chat.py @@ -0,0 +1,291 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import io +import json +import os +import re +import uuid +from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING, Optional + +from ..data import Role as DataRole +from ..extras import logging +from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER +from ..extras.misc import is_env_enabled +from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available +from .common import check_lfi_path, check_ssrf_url, dictify, jsonify +from .protocol import ( + ChatCompletionMessage, + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatCompletionResponseUsage, + ChatCompletionStreamResponse, + ChatCompletionStreamResponseChoice, + Finish, + Function, + FunctionCall, + Role, + ScoreEvaluationResponse, +) + + +if is_fastapi_available(): + from fastapi import HTTPException, status + + +if is_pillow_available(): + from PIL import Image + + +if is_requests_available(): + import requests + + +if TYPE_CHECKING: + from ..chat import ChatModel + from ..data.mm_plugin import AudioInput, ImageInput, VideoInput + from .protocol import ChatCompletionRequest, ScoreEvaluationRequest + + +logger = logging.get_logger(__name__) +ROLE_MAPPING = { + Role.USER: DataRole.USER.value, + Role.ASSISTANT: DataRole.ASSISTANT.value, + Role.SYSTEM: DataRole.SYSTEM.value, + Role.FUNCTION: DataRole.FUNCTION.value, + Role.TOOL: DataRole.OBSERVATION.value, +} + + +def _process_request( + request: "ChatCompletionRequest", +) -> tuple[ + list[dict[str, str]], + Optional[str], + Optional[str], + Optional[list["ImageInput"]], + Optional[list["VideoInput"]], + Optional[list["AudioInput"]], +]: + if is_env_enabled("API_VERBOSE", "1"): + logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}") + + if len(request.messages) == 0: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length") + + if request.messages[0].role == Role.SYSTEM: + content = request.messages.pop(0).content + system = content[0].text if isinstance(content, list) else content + else: + system = None + + if len(request.messages) % 2 == 0: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...") + + input_messages = [] + images, videos, audios = [], [], [] + for i, message in enumerate(request.messages): + if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role") + elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role") + + if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls): + tool_calls = [ + {"name": tool_call.function.name, "arguments": tool_call.function.arguments} + for tool_call in message.tool_calls + ] + content = json.dumps(tool_calls, ensure_ascii=False) + input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content}) + elif isinstance(message.content, list): + text_content = "" + for input_item in message.content: + if input_item.type == "text": + text_content += input_item.text + elif input_item.type == "image_url": + text_content += IMAGE_PLACEHOLDER + image_url = input_item.image_url.url + if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url): # base64 image + image_stream = io.BytesIO(base64.b64decode(image_url.split(",", maxsplit=1)[1])) + elif os.path.isfile(image_url): # local file + check_lfi_path(image_url) + image_stream = open(image_url, "rb") + else: # web uri + check_ssrf_url(image_url) + image_stream = requests.get(image_url, stream=True).raw + + images.append(Image.open(image_stream).convert("RGB")) + elif input_item.type == "video_url": + text_content += VIDEO_PLACEHOLDER + video_url = input_item.video_url.url + if re.match(r"^data:video\/(mp4|mkv|avi|mov);base64,(.+)$", video_url): # base64 video + video_stream = io.BytesIO(base64.b64decode(video_url.split(",", maxsplit=1)[1])) + elif os.path.isfile(video_url): # local file + check_lfi_path(video_url) + video_stream = video_url + else: # web uri + check_ssrf_url(video_url) + video_stream = requests.get(video_url, stream=True).raw + + videos.append(video_stream) + elif input_item.type == "audio_url": + text_content += AUDIO_PLACEHOLDER + audio_url = input_item.audio_url.url + if re.match(r"^data:audio\/(mpeg|mp3|wav|ogg);base64,(.+)$", audio_url): # base64 audio + audio_stream = io.BytesIO(base64.b64decode(audio_url.split(",", maxsplit=1)[1])) + elif os.path.isfile(audio_url): # local file + check_lfi_path(audio_url) + audio_stream = audio_url + else: # web uri + check_ssrf_url(audio_url) + audio_stream = requests.get(audio_url, stream=True).raw + + audios.append(audio_stream) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid input type {input_item.type}." + ) + + input_messages.append({"role": ROLE_MAPPING[message.role], "content": text_content}) + else: + input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content}) + + tool_list = request.tools + if isinstance(tool_list, list) and len(tool_list): + try: + tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False) + except json.JSONDecodeError: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools") + else: + tools = None + + return input_messages, system, tools, images or None, videos or None, audios or None + + +def _create_stream_chat_completion_chunk( + completion_id: str, + model: str, + delta: "ChatCompletionMessage", + index: Optional[int] = 0, + finish_reason: Optional["Finish"] = None, +) -> str: + choice_data = ChatCompletionStreamResponseChoice(index=index, delta=delta, finish_reason=finish_reason) + chunk = ChatCompletionStreamResponse(id=completion_id, model=model, choices=[choice_data]) + return jsonify(chunk) + + +async def create_chat_completion_response( + request: "ChatCompletionRequest", chat_model: "ChatModel" +) -> "ChatCompletionResponse": + completion_id = f"chatcmpl-{uuid.uuid4().hex}" + input_messages, system, tools, images, videos, audios = _process_request(request) + responses = await chat_model.achat( + input_messages, + system, + tools, + images, + videos, + audios, + do_sample=request.do_sample, + temperature=request.temperature, + top_p=request.top_p, + max_new_tokens=request.max_tokens, + num_return_sequences=request.n, + repetition_penalty=request.presence_penalty, + stop=request.stop, + ) + + prompt_length, response_length = 0, 0 + choices = [] + for i, response in enumerate(responses): + if tools: + result = chat_model.engine.template.extract_tool(response.response_text) + else: + result = response.response_text + + if isinstance(result, list): + tool_calls = [] + for tool in result: + function = Function(name=tool.name, arguments=tool.arguments) + tool_calls.append(FunctionCall(id=f"call_{uuid.uuid4().hex}", function=function)) + + response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls) + finish_reason = Finish.TOOL + else: + response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result) + finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH + + choices.append(ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason)) + prompt_length = response.prompt_length + response_length += response.response_length + + usage = ChatCompletionResponseUsage( + prompt_tokens=prompt_length, + completion_tokens=response_length, + total_tokens=prompt_length + response_length, + ) + + return ChatCompletionResponse(id=completion_id, model=request.model, choices=choices, usage=usage) + + +async def create_stream_chat_completion_response( + request: "ChatCompletionRequest", chat_model: "ChatModel" +) -> AsyncGenerator[str, None]: + completion_id = f"chatcmpl-{uuid.uuid4().hex}" + input_messages, system, tools, images, videos, audios = _process_request(request) + if tools: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.") + + if request.n > 1: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream multiple responses.") + + yield _create_stream_chat_completion_chunk( + completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(role=Role.ASSISTANT, content="") + ) + async for new_token in chat_model.astream_chat( + input_messages, + system, + tools, + images, + videos, + audios, + do_sample=request.do_sample, + temperature=request.temperature, + top_p=request.top_p, + max_new_tokens=request.max_tokens, + repetition_penalty=request.presence_penalty, + stop=request.stop, + ): + if len(new_token) != 0: + yield _create_stream_chat_completion_chunk( + completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(content=new_token) + ) + + yield _create_stream_chat_completion_chunk( + completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(), finish_reason=Finish.STOP + ) + yield "[DONE]" + + +async def create_score_evaluation_response( + request: "ScoreEvaluationRequest", chat_model: "ChatModel" +) -> "ScoreEvaluationResponse": + score_id = f"scoreval-{uuid.uuid4().hex}" + if len(request.messages) == 0: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request") + + scores = await chat_model.aget_scores(request.messages, max_length=request.max_length) + return ScoreEvaluationResponse(id=score_id, model=request.model, scores=scores) diff --git a/LlamaFactory/src/llamafactory/api/common.py b/LlamaFactory/src/llamafactory/api/common.py new file mode 100644 index 0000000000000000000000000000000000000000..7b4e9602de7ebc10b4f15c68ad9167cb9d80d8ef --- /dev/null +++ b/LlamaFactory/src/llamafactory/api/common.py @@ -0,0 +1,96 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import ipaddress +import json +import os +import socket +from typing import TYPE_CHECKING, Any +from urllib.parse import urlparse + +from ..extras.misc import is_env_enabled +from ..extras.packages import is_fastapi_available + + +if is_fastapi_available(): + from fastapi import HTTPException, status + + +if TYPE_CHECKING: + from pydantic import BaseModel + + +SAFE_MEDIA_PATH = os.environ.get("SAFE_MEDIA_PATH", os.path.join(os.path.dirname(__file__), "safe_media")) +ALLOW_LOCAL_FILES = is_env_enabled("ALLOW_LOCAL_FILES", "1") + + +def dictify(data: "BaseModel") -> dict[str, Any]: + try: # pydantic v2 + return data.model_dump(exclude_unset=True) + except AttributeError: # pydantic v1 + return data.dict(exclude_unset=True) + + +def jsonify(data: "BaseModel") -> str: + try: # pydantic v2 + return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False) + except AttributeError: # pydantic v1 + return data.json(exclude_unset=True, ensure_ascii=False) + + +def check_lfi_path(path: str) -> None: + """Checks if a given path is vulnerable to LFI. Raises HTTPException if unsafe.""" + if not ALLOW_LOCAL_FILES: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Local file access is disabled.") + + try: + os.makedirs(SAFE_MEDIA_PATH, exist_ok=True) + real_path = os.path.realpath(path) + safe_path = os.path.realpath(SAFE_MEDIA_PATH) + + if not real_path.startswith(safe_path): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail="File access is restricted to the safe media directory." + ) + except Exception: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid or inaccessible file path.") + + +def check_ssrf_url(url: str) -> None: + """Checks if a given URL is vulnerable to SSRF. Raises HTTPException if unsafe.""" + try: + parsed_url = urlparse(url) + if parsed_url.scheme not in ["http", "https"]: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only HTTP/HTTPS URLs are allowed.") + + hostname = parsed_url.hostname + if not hostname: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid URL hostname.") + + ip_info = socket.getaddrinfo(hostname, parsed_url.port) + ip_address_str = ip_info[0][4][0] + ip = ipaddress.ip_address(ip_address_str) + + if not ip.is_global: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Access to private or reserved IP addresses is not allowed.", + ) + + except socket.gaierror: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=f"Could not resolve hostname: {parsed_url.hostname}" + ) + except Exception as e: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid URL: {e}") diff --git a/LlamaFactory/src/llamafactory/api/protocol.py b/LlamaFactory/src/llamafactory/api/protocol.py new file mode 100644 index 0000000000000000000000000000000000000000..675523f062316f3e332d13884e7322aa60050905 --- /dev/null +++ b/LlamaFactory/src/llamafactory/api/protocol.py @@ -0,0 +1,156 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from enum import Enum, unique +from typing import Any, Literal + +from pydantic import BaseModel, Field + + +@unique +class Role(str, Enum): + USER = "user" + ASSISTANT = "assistant" + SYSTEM = "system" + FUNCTION = "function" + TOOL = "tool" + + +@unique +class Finish(str, Enum): + STOP = "stop" + LENGTH = "length" + TOOL = "tool_calls" + + +class ModelCard(BaseModel): + id: str + object: Literal["model"] = "model" + created: int = Field(default_factory=lambda: int(time.time())) + owned_by: Literal["owner"] = "owner" + + +class ModelList(BaseModel): + object: Literal["list"] = "list" + data: list[ModelCard] = [] + + +class Function(BaseModel): + name: str + arguments: str + + +class FunctionDefinition(BaseModel): + name: str + description: str + parameters: dict[str, Any] + + +class FunctionAvailable(BaseModel): + type: Literal["function", "code_interpreter"] = "function" + function: FunctionDefinition | None = None + + +class FunctionCall(BaseModel): + id: str + type: Literal["function"] = "function" + function: Function + + +class URL(BaseModel): + url: str + detail: Literal["auto", "low", "high"] = "auto" + + +class MultimodalInputItem(BaseModel): + type: Literal["text", "image_url", "video_url", "audio_url"] + text: str | None = None + image_url: URL | None = None + video_url: URL | None = None + audio_url: URL | None = None + + +class ChatMessage(BaseModel): + role: Role + content: str | list[MultimodalInputItem] | None = None + tool_calls: list[FunctionCall] | None = None + + +class ChatCompletionMessage(BaseModel): + role: Role | None = None + content: str | None = None + tool_calls: list[FunctionCall] | None = None + + +class ChatCompletionRequest(BaseModel): + model: str + messages: list[ChatMessage] + tools: list[FunctionAvailable] | None = None + do_sample: bool | None = None + temperature: float | None = None + top_p: float | None = None + n: int = 1 + presence_penalty: float | None = None + max_tokens: int | None = None + stop: str | list[str] | None = None + stream: bool = False + + +class ChatCompletionResponseChoice(BaseModel): + index: int + message: ChatCompletionMessage + finish_reason: Finish + + +class ChatCompletionStreamResponseChoice(BaseModel): + index: int + delta: ChatCompletionMessage + finish_reason: Finish | None = None + + +class ChatCompletionResponseUsage(BaseModel): + prompt_tokens: int + completion_tokens: int + total_tokens: int + + +class ChatCompletionResponse(BaseModel): + id: str + object: Literal["chat.completion"] = "chat.completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: list[ChatCompletionResponseChoice] + usage: ChatCompletionResponseUsage + + +class ChatCompletionStreamResponse(BaseModel): + id: str + object: Literal["chat.completion.chunk"] = "chat.completion.chunk" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: list[ChatCompletionStreamResponseChoice] + + +class ScoreEvaluationRequest(BaseModel): + model: str + messages: list[str] + max_length: int | None = None + + +class ScoreEvaluationResponse(BaseModel): + id: str + object: Literal["score.evaluation"] = "score.evaluation" + model: str + scores: list[float] diff --git a/LlamaFactory/src/llamafactory/chat/__init__.py b/LlamaFactory/src/llamafactory/chat/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..15d8b9ba2d77d6f300d59300da5a49abd3ed4e57 --- /dev/null +++ b/LlamaFactory/src/llamafactory/chat/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .base_engine import BaseEngine +from .chat_model import ChatModel + + +__all__ = ["BaseEngine", "ChatModel"] diff --git a/LlamaFactory/src/llamafactory/chat/__pycache__/__init__.cpython-311.pyc b/LlamaFactory/src/llamafactory/chat/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9cee02d9d1a045f8a54416a215595cbb48e5816f Binary files /dev/null and b/LlamaFactory/src/llamafactory/chat/__pycache__/__init__.cpython-311.pyc differ diff --git a/LlamaFactory/src/llamafactory/chat/__pycache__/__init__.cpython-312.pyc b/LlamaFactory/src/llamafactory/chat/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3da5891b54cec08c062fe652998947ccaf1fcaa Binary files /dev/null and b/LlamaFactory/src/llamafactory/chat/__pycache__/__init__.cpython-312.pyc differ diff --git a/LlamaFactory/src/llamafactory/chat/__pycache__/base_engine.cpython-311.pyc b/LlamaFactory/src/llamafactory/chat/__pycache__/base_engine.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eaa3ffce5a899d999629eb9bdf1ad3b17e564ba4 Binary files /dev/null and b/LlamaFactory/src/llamafactory/chat/__pycache__/base_engine.cpython-311.pyc differ diff --git a/LlamaFactory/src/llamafactory/chat/__pycache__/base_engine.cpython-312.pyc b/LlamaFactory/src/llamafactory/chat/__pycache__/base_engine.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d40dd46f9441465373d3735d1ca171547a729d0 Binary files /dev/null and b/LlamaFactory/src/llamafactory/chat/__pycache__/base_engine.cpython-312.pyc differ diff --git a/LlamaFactory/src/llamafactory/chat/__pycache__/chat_model.cpython-311.pyc b/LlamaFactory/src/llamafactory/chat/__pycache__/chat_model.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4008b23e56e416e6e7bacc42f2ba4afb375fcc7b Binary files /dev/null and b/LlamaFactory/src/llamafactory/chat/__pycache__/chat_model.cpython-311.pyc differ diff --git a/LlamaFactory/src/llamafactory/chat/__pycache__/chat_model.cpython-312.pyc b/LlamaFactory/src/llamafactory/chat/__pycache__/chat_model.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df0f0ea481e3dce319c80ac645d08ef82d1b9243 Binary files /dev/null and b/LlamaFactory/src/llamafactory/chat/__pycache__/chat_model.cpython-312.pyc differ diff --git a/LlamaFactory/src/llamafactory/chat/__pycache__/hf_engine.cpython-311.pyc b/LlamaFactory/src/llamafactory/chat/__pycache__/hf_engine.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..991ca0f69dfe2b759c640660ad7fb4a2f6d2e08f Binary files /dev/null and b/LlamaFactory/src/llamafactory/chat/__pycache__/hf_engine.cpython-311.pyc differ diff --git a/LlamaFactory/src/llamafactory/chat/__pycache__/hf_engine.cpython-312.pyc b/LlamaFactory/src/llamafactory/chat/__pycache__/hf_engine.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6b47bbaa4dcec453edf6eae8fc3e065361dfa1b Binary files /dev/null and b/LlamaFactory/src/llamafactory/chat/__pycache__/hf_engine.cpython-312.pyc differ diff --git a/LlamaFactory/src/llamafactory/chat/base_engine.py b/LlamaFactory/src/llamafactory/chat/base_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..6d497c1ae927f94f396c18833b18cdb894cbd59d --- /dev/null +++ b/LlamaFactory/src/llamafactory/chat/base_engine.py @@ -0,0 +1,98 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from collections.abc import AsyncGenerator +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Literal, Optional, Union + + +if TYPE_CHECKING: + from transformers import PreTrainedModel, PreTrainedTokenizer + from vllm import AsyncLLMEngine + + from ..data import Template + from ..data.mm_plugin import AudioInput, ImageInput, VideoInput + from ..extras.constants import EngineName + from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments + + +@dataclass +class Response: + response_text: str + response_length: int + prompt_length: int + finish_reason: Literal["stop", "length"] + + +class BaseEngine(ABC): + r"""Base class for inference engine of chat models. + + Must implements async methods: chat(), stream_chat() and get_scores(). + """ + + name: "EngineName" + model: Union["PreTrainedModel", "AsyncLLMEngine"] + tokenizer: "PreTrainedTokenizer" + can_generate: bool + template: "Template" + generating_args: dict[str, Any] + + @abstractmethod + def __init__( + self, + model_args: "ModelArguments", + data_args: "DataArguments", + finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", + ) -> None: + r"""Initialize an inference engine.""" + ... + + @abstractmethod + async def chat( + self, + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, + **input_kwargs, + ) -> list["Response"]: + r"""Get a list of responses of the chat model.""" + ... + + @abstractmethod + async def stream_chat( + self, + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, + **input_kwargs, + ) -> AsyncGenerator[str, None]: + r"""Get the response token-by-token of the chat model.""" + ... + + @abstractmethod + async def get_scores( + self, + batch_input: list[str], + **input_kwargs, + ) -> list[float]: + r"""Get a list of scores of the reward model.""" + ... diff --git a/LlamaFactory/src/llamafactory/chat/chat_model.py b/LlamaFactory/src/llamafactory/chat/chat_model.py new file mode 100644 index 0000000000000000000000000000000000000000..cb612f88d468d76f06eefa45b96c1bfa0351fa7c --- /dev/null +++ b/LlamaFactory/src/llamafactory/chat/chat_model.py @@ -0,0 +1,210 @@ +# Copyright 2025 THUDM and the LlamaFactory team. +# +# This code is inspired by the THUDM's ChatGLM implementation. +# https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import os +from collections.abc import AsyncGenerator, Generator +from threading import Thread +from typing import TYPE_CHECKING, Any, Optional + +from ..extras.constants import EngineName +from ..extras.misc import torch_gc +from ..hparams import get_infer_args + + +if TYPE_CHECKING: + from ..data.mm_plugin import AudioInput, ImageInput, VideoInput + from .base_engine import BaseEngine, Response + + +def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None: + asyncio.set_event_loop(loop) + loop.run_forever() + + +class ChatModel: + r"""General class for chat models. Backed by huggingface or vllm engines. + + Supports both sync and async methods. + Sync methods: chat(), stream_chat() and get_scores(). + Async methods: achat(), astream_chat() and aget_scores(). + """ + + def __init__(self, args: Optional[dict[str, Any]] = None) -> None: + model_args, data_args, finetuning_args, generating_args = get_infer_args(args) + + if model_args.infer_backend == EngineName.HF: + from .hf_engine import HuggingfaceEngine + + self.engine: BaseEngine = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args) + elif model_args.infer_backend == EngineName.VLLM: + try: + from .vllm_engine import VllmEngine + + self.engine: BaseEngine = VllmEngine(model_args, data_args, finetuning_args, generating_args) + except ImportError as e: + raise ImportError( + "vLLM not install, you may need to run `pip install vllm`\n" + "or try to use HuggingFace backend: --infer_backend huggingface" + ) from e + elif model_args.infer_backend == EngineName.SGLANG: + try: + from .sglang_engine import SGLangEngine + + self.engine: BaseEngine = SGLangEngine(model_args, data_args, finetuning_args, generating_args) + except ImportError as e: + raise ImportError( + "SGLang not install, you may need to run `pip install sglang[all]`\n" + "or try to use HuggingFace backend: --infer_backend huggingface" + ) from e + elif model_args.infer_backend == EngineName.KT: + try: + from .kt_engine import KTransformersEngine + + self.engine: BaseEngine = KTransformersEngine(model_args, data_args, finetuning_args, generating_args) + except ImportError as e: + raise ImportError( + "KTransformers not install, you may need to run `pip install ktransformers`\n" + "or try to use HuggingFace backend: --infer_backend huggingface" + ) from e + else: + raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}") + + self._loop = asyncio.new_event_loop() + self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True) + self._thread.start() + + def chat( + self, + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, + **input_kwargs, + ) -> list["Response"]: + r"""Get a list of responses of the chat model.""" + task = asyncio.run_coroutine_threadsafe( + self.achat(messages, system, tools, images, videos, audios, **input_kwargs), self._loop + ) + return task.result() + + async def achat( + self, + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, + **input_kwargs, + ) -> list["Response"]: + r"""Asynchronously get a list of responses of the chat model.""" + return await self.engine.chat(messages, system, tools, images, videos, audios, **input_kwargs) + + def stream_chat( + self, + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, + **input_kwargs, + ) -> Generator[str, None, None]: + r"""Get the response token-by-token of the chat model.""" + generator = self.astream_chat(messages, system, tools, images, videos, audios, **input_kwargs) + while True: + try: + task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop) + yield task.result() + except StopAsyncIteration: + break + + async def astream_chat( + self, + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, + **input_kwargs, + ) -> AsyncGenerator[str, None]: + r"""Asynchronously get the response token-by-token of the chat model.""" + async for new_token in self.engine.stream_chat( + messages, system, tools, images, videos, audios, **input_kwargs + ): + yield new_token + + def get_scores( + self, + batch_input: list[str], + **input_kwargs, + ) -> list[float]: + r"""Get a list of scores of the reward model.""" + task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop) + return task.result() + + async def aget_scores( + self, + batch_input: list[str], + **input_kwargs, + ) -> list[float]: + r"""Asynchronously get a list of scores of the reward model.""" + return await self.engine.get_scores(batch_input, **input_kwargs) + + +def run_chat() -> None: + if os.name != "nt": + try: + import readline # noqa: F401 + except ImportError: + print("Install `readline` for a better experience.") + + chat_model = ChatModel() + messages = [] + print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.") + + while True: + try: + query = input("\nUser: ") + except UnicodeDecodeError: + print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.") + continue + except Exception: + raise + + if query.strip() == "exit": + break + + if query.strip() == "clear": + messages = [] + torch_gc() + print("History has been removed.") + continue + + messages.append({"role": "user", "content": query}) + print("Assistant: ", end="", flush=True) + + response = "" + for new_text in chat_model.stream_chat(messages): + print(new_text, end="", flush=True) + response += new_text + print() + messages.append({"role": "assistant", "content": response}) diff --git a/LlamaFactory/src/llamafactory/chat/hf_engine.py b/LlamaFactory/src/llamafactory/chat/hf_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..1e670b92c99e3d50362184fdd690cd372fe033d6 --- /dev/null +++ b/LlamaFactory/src/llamafactory/chat/hf_engine.py @@ -0,0 +1,412 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import os +from collections.abc import AsyncGenerator, Callable +from threading import Thread +from typing import TYPE_CHECKING, Any, Optional, Union + +import torch +from transformers import GenerationConfig, TextIteratorStreamer +from typing_extensions import override + +from ..data import get_template_and_fix_tokenizer +from ..extras import logging +from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName +from ..model import load_model, load_tokenizer +from .base_engine import BaseEngine, Response + + +if TYPE_CHECKING: + from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin + from trl import PreTrainedModelWrapper + + from ..data import Template + from ..data.mm_plugin import AudioInput, ImageInput, VideoInput + from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments + + +logger = logging.get_logger(__name__) + + +class HuggingfaceEngine(BaseEngine): + def __init__( + self, + model_args: "ModelArguments", + data_args: "DataArguments", + finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", + ) -> None: + self.name = EngineName.HF + self.can_generate = finetuning_args.stage == "sft" + tokenizer_module = load_tokenizer(model_args) + self.tokenizer = tokenizer_module["tokenizer"] + self.processor = tokenizer_module["processor"] + self.tokenizer.padding_side = "left" if self.can_generate else "right" + self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args) + self.model = load_model( + self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate) + ) # must after fixing tokenizer to resize vocab + self.generating_args = generating_args.to_dict() + try: + asyncio.get_event_loop() + except RuntimeError: + logger.warning_rank0_once("There is no current event loop, creating a new one.") + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1"))) + + @staticmethod + def _process_args( + model: "PreTrainedModel", + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"], + template: "Template", + generating_args: dict[str, Any], + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, + input_kwargs: Optional[dict[str, Any]] = {}, + ) -> tuple[dict[str, Any], int]: + mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]} + if images is not None: + mm_input_dict.update({"images": images, "imglens": [len(images)]}) + if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages): + messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"] + + if videos is not None: + mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]}) + if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages): + messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"] + + if audios is not None: + mm_input_dict.update({"audios": audios, "audlens": [len(audios)]}) + if not any(AUDIO_PLACEHOLDER in message["content"] for message in messages): + messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"] + + messages = template.mm_plugin.process_messages( + messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], processor + ) + paired_messages = messages + [{"role": "assistant", "content": ""}] + prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools) + prompt_ids, _ = template.mm_plugin.process_token_ids( + prompt_ids, + None, + mm_input_dict["images"], + mm_input_dict["videos"], + mm_input_dict["audios"], + tokenizer, + processor, + ) + prompt_length = len(prompt_ids) + inputs = torch.tensor([prompt_ids], device=model.device) + attention_mask = torch.ones_like(inputs, dtype=torch.long) + + do_sample: Optional[bool] = input_kwargs.pop("do_sample", None) + temperature: Optional[float] = input_kwargs.pop("temperature", None) + top_p: Optional[float] = input_kwargs.pop("top_p", None) + top_k: Optional[float] = input_kwargs.pop("top_k", None) + num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1) + repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None) + length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None) + skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None) + max_length: Optional[int] = input_kwargs.pop("max_length", None) + max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None) + stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None) + + if stop is not None: + logger.warning_rank0("Stop parameter is not supported by the huggingface engine yet.") + + generating_args = generating_args.copy() + generating_args.update( + dict( + do_sample=do_sample if do_sample is not None else generating_args["do_sample"], + temperature=temperature if temperature is not None else generating_args["temperature"], + top_p=top_p if top_p is not None else generating_args["top_p"], + top_k=top_k if top_k is not None else generating_args["top_k"], + num_return_sequences=num_return_sequences, + repetition_penalty=repetition_penalty + if repetition_penalty is not None + else generating_args["repetition_penalty"], + length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"], + skip_special_tokens=skip_special_tokens + if skip_special_tokens is not None + else generating_args["skip_special_tokens"], + eos_token_id=template.get_stop_token_ids(tokenizer), + pad_token_id=tokenizer.pad_token_id, + ) + ) + + if isinstance(num_return_sequences, int) and num_return_sequences > 1: # do_sample needs temperature > 0 + generating_args["do_sample"] = True + generating_args["temperature"] = generating_args["temperature"] or 1.0 + + if not generating_args["temperature"]: + generating_args["do_sample"] = False + + if not generating_args["do_sample"]: + generating_args.pop("temperature", None) + generating_args.pop("top_p", None) + + if max_length: + generating_args.pop("max_new_tokens", None) + generating_args["max_length"] = max_length + + if max_new_tokens: + generating_args.pop("max_length", None) + generating_args["max_new_tokens"] = max_new_tokens + + gen_kwargs = dict( + inputs=inputs, + attention_mask=attention_mask, + generation_config=GenerationConfig(**generating_args), + ) + + mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, batch_ids=[prompt_ids], processor=processor) + for key, value in mm_inputs.items(): + if isinstance(value, list) and isinstance(value[0], torch.Tensor): # for pixtral inputs + value = torch.stack(value) # assume they have same sizes + elif ( + isinstance(value, list) and isinstance(value[0], list) and isinstance(value[0][0], torch.Tensor) + ): # for minicpmv inputs + value = torch.stack([torch.stack(v) for v in value]) + elif not isinstance(value, torch.Tensor): + value = torch.tensor(value) + + if torch.is_floating_point(value): # cast data dtype for paligemma + value = value.to(model.dtype) + + if key == "second_per_grid_ts": # qwen2.5vl special case + gen_kwargs[key] = value.tolist() + else: + gen_kwargs[key] = value.to(model.device) + + if getattr(model.config, "model_type", None) in ["minicpmv", "minicpmo"]: + gen_kwargs["input_ids"] = inputs + gen_kwargs["tokenizer"] = tokenizer + if "audio_feature_lens" in mm_inputs: + gen_kwargs["audio_feature_lens"] = mm_inputs["audio_feature_lens"] + + gen_kwargs.pop("image_sizes", None) + + return gen_kwargs, prompt_length + + @staticmethod + @torch.inference_mode() + def _chat( + model: "PreTrainedModel", + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"], + template: "Template", + generating_args: dict[str, Any], + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, + input_kwargs: Optional[dict[str, Any]] = {}, + ) -> list["Response"]: + gen_kwargs, prompt_length = HuggingfaceEngine._process_args( + model, + tokenizer, + processor, + template, + generating_args, + messages, + system, + tools, + images, + videos, + audios, + input_kwargs, + ) + generate_output = model.generate(**gen_kwargs) + if isinstance(generate_output, tuple): + generate_output = generate_output[1][0] # post-process the minicpm_o output + + response_ids = generate_output[:, prompt_length:] + response = tokenizer.batch_decode( + response_ids, + skip_special_tokens=getattr(gen_kwargs["generation_config"], "skip_special_tokens", True), + clean_up_tokenization_spaces=True, + ) + results = [] + for i in range(len(response)): + eos_index = (response_ids[i] == tokenizer.eos_token_id).nonzero() + response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i]) + results.append( + Response( + response_text=response[i], + response_length=response_length, + prompt_length=prompt_length, + finish_reason="stop" if len(eos_index) else "length", + ) + ) + + return results + + @staticmethod + @torch.inference_mode() + def _stream_chat( + model: "PreTrainedModel", + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"], + template: "Template", + generating_args: dict[str, Any], + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, + input_kwargs: Optional[dict[str, Any]] = {}, + ) -> Callable[[], str]: + gen_kwargs, _ = HuggingfaceEngine._process_args( + model, + tokenizer, + processor, + template, + generating_args, + messages, + system, + tools, + images, + videos, + audios, + input_kwargs, + ) + streamer = TextIteratorStreamer( + tokenizer, + skip_prompt=True, + skip_special_tokens=getattr(gen_kwargs["generation_config"], "skip_special_tokens", True), + ) + gen_kwargs["streamer"] = streamer + thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True) + thread.start() + + def stream(): + try: + return streamer.__next__() + except StopIteration: + raise StopAsyncIteration() + + return stream + + @staticmethod + @torch.inference_mode() + def _get_scores( + model: "PreTrainedModelWrapper", + tokenizer: "PreTrainedTokenizer", + batch_input: list[str], + input_kwargs: Optional[dict[str, Any]] = {}, + ) -> list[float]: + max_length: Optional[int] = input_kwargs.pop("max_length", None) + device = getattr(model.pretrained_model, "device", "cuda") + inputs: dict[str, torch.Tensor] = tokenizer( + batch_input, + padding=True, + truncation=True, + max_length=max_length or getattr(model.config, "max_position_embeddings", 1024), + return_tensors="pt", + add_special_tokens=False, + ).to(device) + values: torch.Tensor = model(**inputs, return_dict=True, use_cache=False)[-1] + scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1)) + return scores + + @override + async def chat( + self, + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, + **input_kwargs, + ) -> list["Response"]: + if not self.can_generate: + raise ValueError("The current model does not support `chat`.") + + input_args = ( + self.model, + self.tokenizer, + self.processor, + self.template, + self.generating_args, + messages, + system, + tools, + images, + videos, + audios, + input_kwargs, + ) + async with self.semaphore: + return await asyncio.to_thread(self._chat, *input_args) + + @override + async def stream_chat( + self, + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, + **input_kwargs, + ) -> AsyncGenerator[str, None]: + if not self.can_generate: + raise ValueError("The current model does not support `stream_chat`.") + + input_args = ( + self.model, + self.tokenizer, + self.processor, + self.template, + self.generating_args, + messages, + system, + tools, + images, + videos, + audios, + input_kwargs, + ) + async with self.semaphore: + stream = self._stream_chat(*input_args) + while True: + try: + yield await asyncio.to_thread(stream) + except StopAsyncIteration: + break + + @override + async def get_scores( + self, + batch_input: list[str], + **input_kwargs, + ) -> list[float]: + if self.can_generate: + raise ValueError("Cannot get scores using an auto-regressive model.") + + input_args = (self.model, self.tokenizer, batch_input, input_kwargs) + async with self.semaphore: + return await asyncio.to_thread(self._get_scores, *input_args) diff --git a/LlamaFactory/src/llamafactory/chat/kt_engine.py b/LlamaFactory/src/llamafactory/chat/kt_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..3bf3f4bb2b685ee971d538d29f0b6afa16956f2c --- /dev/null +++ b/LlamaFactory/src/llamafactory/chat/kt_engine.py @@ -0,0 +1,284 @@ +# Copyright 2025 the KVCache.AI team, Approaching AI, and the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import os +import platform +from collections.abc import AsyncGenerator +from threading import Thread +from typing import TYPE_CHECKING, Any, Optional + +import torch +from typing_extensions import override + +from ..data import get_template_and_fix_tokenizer +from ..extras import logging +from ..extras.constants import EngineName +from ..model import load_model, load_tokenizer +from .base_engine import BaseEngine, Response + + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer + from trl import PreTrainedModelWrapper + + from ..data.mm_plugin import AudioInput, ImageInput, VideoInput + from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments + +from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled +from ktransformers.server.config.config import Config +from ktransformers.util.utils import ( + get_compute_capability, + prefill_and_generate_capture, +) +from ktransformers.util.vendors import GPUVendor, device_manager + + +logger = logging.get_logger(__name__) + + +class KTransformersEngine(BaseEngine): + def __init__( + self, + model_args: "ModelArguments", + data_args: "DataArguments", + finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", + ) -> None: + self.name = EngineName.KT + self.can_generate = finetuning_args.stage == "sft" + + tok_mod = load_tokenizer(model_args) + self.tokenizer = tok_mod["tokenizer"] + self.tokenizer.padding_side = "left" if self.can_generate else "right" + self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args) + + self.model = load_model( + self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate) + ) + + self.generating_args = generating_args.to_dict() + self.max_new_tokens = model_args.kt_maxlen + self.use_cuda_graph = model_args.kt_use_cuda_graph + self.mode = model_args.kt_mode + self.force_think = model_args.kt_force_think + self.chunk_size = model_args.chunk_size + + try: + asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1"))) + + @staticmethod + @torch.inference_mode() + def _get_scores( + model: "PreTrainedModelWrapper", + tokenizer: "PreTrainedTokenizer", + batch_input: list[str], + input_kwargs: Optional[dict[str, Any]] = {}, + ) -> list[float]: + max_length: Optional[int] = input_kwargs.pop("max_length", None) + device = getattr(model.pretrained_model, "device", "cuda") + inputs = tokenizer( + batch_input, + padding=True, + truncation=True, + max_length=max_length or getattr(model.config, "max_position_embeddings", 1024), + return_tensors="pt", + add_special_tokens=False, + ).to(device) + values: torch.Tensor = model(**inputs, return_dict=True, use_cache=False)[-1] + scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1)) + return scores + + async def _generate( + self, + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + **input_kwargs, + ) -> AsyncGenerator[str, None]: + paired = messages + [{"role": "assistant", "content": ""}] + prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired, system, tools) + prompt_len = len(prompt_ids) + + max_length: Optional[int] = input_kwargs.pop("max_length", None) + max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None) + + if "max_new_tokens" in self.generating_args: + max_tokens = int(self.generating_args["max_new_tokens"]) + elif "max_length" in self.generating_args: + gl = int(self.generating_args["max_length"]) + max_tokens = gl - prompt_len if gl > prompt_len else 1 + else: + max_tokens = self.max_new_tokens or 256 + + if max_length is not None: + max_tokens = max(max_length - prompt_len, 1) + if max_new_tokens is not None: + max_tokens = int(max_new_tokens) + max_tokens = max(1, int(max_tokens)) + + if self.mode == "long_context": + max_len_cfg = Config().long_context_config["max_seq_len"] + need = prompt_len + max_tokens + assert max_len_cfg > need, f"please set max_seq_len > {need} in ~/.ktransformers/config.yaml" + + device = next(self.model.parameters()).device + input_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device) + if self.force_think: + think = torch.tensor( + [self.tokenizer.encode("\n", add_special_tokens=False)], dtype=torch.long, device=device + ) + input_tensor = torch.cat([input_tensor, think], dim=1) + + use_flashinfer = ( + platform.system() != "Windows" + and getattr(self.model.config, "architectures", [""])[0] + in {"DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"} + and flashinfer_enabled + and get_compute_capability() >= 8 + and device_manager.gpu_vendor == GPUVendor.NVIDIA + ) + + def make_gen(): + if use_flashinfer: + return prefill_and_generate_capture( + self.model, + self.tokenizer, + input_tensor, + max_tokens, + self.use_cuda_graph, + mode=self.mode, + force_think=self.force_think, + chunk_size=self.chunk_size, + use_flashinfer_mla=True, + num_heads=self.model.config.num_attention_heads, + head_dim_ckv=getattr(self.model.config, "kv_lora_rank", 0), + head_dim_kpe=getattr(self.model.config, "qk_rope_head_dim", 0), + q_head_dim=getattr(self.model.config, "qk_rope_head_dim", 0) + + getattr(self.model.config, "qk_nope_head_dim", 0), + echo_stream=False, + ) + else: + return prefill_and_generate_capture( + self.model, + self.tokenizer, + input_tensor, + max_tokens, + self.use_cuda_graph, + mode=self.mode, + force_think=self.force_think, + chunk_size=self.chunk_size, + echo_stream=False, + ) + + loop = asyncio.get_running_loop() + q: asyncio.Queue[Optional[str]] = asyncio.Queue() + + def producer(): + try: + gen = make_gen() + if hasattr(gen, "__aiter__"): + + async def drain_async(): + async for t in gen: + loop.call_soon_threadsafe(q.put_nowait, t if isinstance(t, str) else str(t)) + + asyncio.run(drain_async()) + elif hasattr(gen, "__iter__"): + for t in gen: + loop.call_soon_threadsafe(q.put_nowait, t if isinstance(t, str) else str(t)) + else: + loop.call_soon_threadsafe(q.put_nowait, gen if isinstance(gen, str) else str(gen)) + finally: + loop.call_soon_threadsafe(q.put_nowait, None) + + Thread(target=producer, daemon=True).start() + + while True: + item = await q.get() + if item is None: + break + yield item + + @override + async def chat( + self, + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, + **input_kwargs, + ) -> list["Response"]: + if not self.can_generate: + raise ValueError("The current model does not support `chat`.") + async with self.semaphore: + produced = "" + final_text = "" + async for t in self._generate(messages, system, tools, **input_kwargs): + delta = t + produced = produced + delta + if delta: + final_text += delta + + prompt_ids, _ = self.template.encode_oneturn( + self.tokenizer, messages + [{"role": "assistant", "content": ""}], system, tools + ) + return [ + Response( + response_text=final_text, + response_length=len(self.tokenizer.encode(final_text, add_special_tokens=False)), + prompt_length=len(prompt_ids), + finish_reason="stop", + ) + ] + + @override + async def stream_chat( + self, + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, + **input_kwargs, + ) -> AsyncGenerator[str, None]: + if not self.can_generate: + raise ValueError("The current model does not support `stream_chat`.") + async with self.semaphore: + produced = "" + async for t in self._generate(messages, system, tools, **input_kwargs): + delta = t[len(produced) :] if t.startswith(produced) else t + produced = t + if delta: + yield delta + + @override + async def get_scores( + self, + batch_input: list[str], + **input_kwargs, + ) -> list[float]: + if self.can_generate: + raise ValueError("Cannot get scores using an auto-regressive model.") + args = (self.model, self.tokenizer, batch_input, input_kwargs) + async with self.semaphore: + return await asyncio.to_thread(self._get_scores, *args) diff --git a/LlamaFactory/src/llamafactory/chat/sglang_engine.py b/LlamaFactory/src/llamafactory/chat/sglang_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..b1d2ead33823bc70d51cda59750d25580f972083 --- /dev/null +++ b/LlamaFactory/src/llamafactory/chat/sglang_engine.py @@ -0,0 +1,289 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import atexit +import json +from collections.abc import AsyncGenerator, AsyncIterator, Sequence +from typing import TYPE_CHECKING, Any, Optional, Union + +import requests +from typing_extensions import override + +from ..data import get_template_and_fix_tokenizer +from ..extras import logging +from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName +from ..extras.misc import get_device_count, torch_gc +from ..extras.packages import is_sglang_available +from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments +from ..model import load_config, load_tokenizer +from ..model.model_utils.quantization import QuantizationMethod +from .base_engine import BaseEngine, Response + + +if is_sglang_available(): + from sglang.utils import launch_server_cmd, terminate_process, wait_for_server # type: ignore + + +if TYPE_CHECKING: + from ..data.mm_plugin import AudioInput, ImageInput, VideoInput + + +logger = logging.get_logger(__name__) + + +class SGLangEngine(BaseEngine): + """Inference engine for SGLang models. + + This class wraps the SGLang engine to provide a consistent interface for text generation + that matches LLaMA Factory's requirements. It uses the SGLang HTTP server approach for + better interaction and performance. The engine launches a server process and communicates + with it via HTTP requests. + + For more details on the SGLang HTTP server approach, see: + https://docs.sglang.ai/backend/send_request.html + """ + + def __init__( + self, + model_args: "ModelArguments", + data_args: "DataArguments", + finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", + ) -> None: + self.name = EngineName.SGLANG + self.model_args = model_args + config = load_config(model_args) # may download model from ms hub + if getattr(config, "quantization_config", None): # gptq models should use float16 + quantization_config: dict[str, Any] = getattr(config, "quantization_config", None) + quant_method = quantization_config.get("quant_method", "") + if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto": + model_args.infer_dtype = "float16" + + self.can_generate = finetuning_args.stage == "sft" + tokenizer_module = load_tokenizer(model_args) + self.tokenizer = tokenizer_module["tokenizer"] + self.processor = tokenizer_module["processor"] + self.tokenizer.padding_side = "left" + self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args) + self.template.mm_plugin.expand_mm_tokens = False # for sglang generate + self.generating_args = generating_args.to_dict() + if model_args.adapter_name_or_path is not None: + self.lora_request = True + else: + self.lora_request = False + + launch_cmd = [ + "python3 -m sglang.launch_server", + f"--model-path {model_args.model_name_or_path}", + f"--dtype {model_args.infer_dtype}", + f"--context-length {model_args.sglang_maxlen}", + f"--mem-fraction-static {model_args.sglang_mem_fraction}", + f"--tp-size {model_args.sglang_tp_size if model_args.sglang_tp_size != -1 else get_device_count() or 1}", + f"--download-dir {model_args.cache_dir}", + "--log-level error", + ] + if self.lora_request: + launch_cmd.extend( + [ + "--max-loras-per-batch 1", + f"--lora-backend {model_args.sglang_lora_backend}", + f"--lora-paths lora0={model_args.adapter_name_or_path[0]}", + "--disable-radix-cache", + ] + ) + launch_cmd = " ".join(launch_cmd) + logger.info_rank0(f"Starting SGLang server with command: {launch_cmd}") + try: + torch_gc() + self.server_process, port = launch_server_cmd(launch_cmd) + self.base_url = f"http://localhost:{port}" + atexit.register(self._cleanup_server) + + logger.info_rank0(f"Waiting for SGLang server to be ready at {self.base_url}") + wait_for_server(self.base_url, timeout=300) + logger.info_rank0(f"SGLang server initialized successfully at {self.base_url}") + try: + response = requests.get(f"{self.base_url}/get_model_info", timeout=5) + if response.status_code == 200: + model_info = response.json() + logger.info(f"SGLang server model info: {model_info}") + except Exception as e: + logger.debug(f"Note: could not get model info: {str(e)}") + + except Exception as e: + logger.error(f"Failed to start SGLang server: {str(e)}") + self._cleanup_server() # make sure to clean up any started process + raise RuntimeError(f"SGLang server initialization failed: {str(e)}.") + + def _cleanup_server(self): + r"""Clean up the server process when the engine is destroyed.""" + if hasattr(self, "server_process") and self.server_process: + try: + logger.info("Terminating SGLang server process") + terminate_process(self.server_process) + logger.info("SGLang server process terminated") + except Exception as e: + logger.warning(f"Error terminating SGLang server: {str(e)}") + + async def _generate( + self, + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, + **input_kwargs, + ) -> AsyncIterator[dict[str, Any]]: + if images is not None and not any(IMAGE_PLACEHOLDER in message["content"] for message in messages): + messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"] + + if videos is not None and not any(VIDEO_PLACEHOLDER in message["content"] for message in messages): + messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"] + + if audios is not None and not any(AUDIO_PLACEHOLDER in message["content"] for message in messages): + messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"] + + messages = self.template.mm_plugin.process_messages( + messages, images or [], videos or [], audios or [], self.processor + ) + paired_messages = messages + [{"role": "assistant", "content": ""}] + prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools) + prompt_length = len(prompt_ids) + + temperature: Optional[float] = input_kwargs.pop("temperature", None) + top_p: Optional[float] = input_kwargs.pop("top_p", None) + top_k: Optional[float] = input_kwargs.pop("top_k", None) + num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1) + repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None) + skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None) + max_length: Optional[int] = input_kwargs.pop("max_length", None) + max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None) + stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None) + + if num_return_sequences != 1: + raise NotImplementedError("SGLang only supports n=1.") + + if "max_new_tokens" in self.generating_args: + max_tokens = self.generating_args["max_new_tokens"] + elif "max_length" in self.generating_args: + if self.generating_args["max_length"] > prompt_length: + max_tokens = self.generating_args["max_length"] - prompt_length + else: + max_tokens = 1 + + if max_length: + max_tokens = max_length - prompt_length if max_length > prompt_length else 1 + + if max_new_tokens: + max_tokens = max_new_tokens + + sampling_params = { + "temperature": temperature if temperature is not None else self.generating_args["temperature"], + "top_p": (top_p if top_p is not None else self.generating_args["top_p"]) or 1.0, # top_p must > 0 + "top_k": (top_k if top_k is not None else self.generating_args["top_k"]) or -1, # top_k must > 0 + "stop": stop, + "stop_token_ids": self.template.get_stop_token_ids(self.tokenizer), + "max_new_tokens": max_tokens, + "repetition_penalty": ( + repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"] + ) + or 1.0, # repetition_penalty must > 0 + "skip_special_tokens": skip_special_tokens + if skip_special_tokens is not None + else self.generating_args["skip_special_tokens"], + } + + def stream_request(): + json_data = { + "input_ids": prompt_ids, + "sampling_params": sampling_params, + "stream": True, + } + if self.lora_request: + json_data["lora_request"] = ["lora0"] + response = requests.post(f"{self.base_url}/generate", json=json_data, stream=True) + if response.status_code != 200: + raise RuntimeError(f"SGLang server error: {response.status_code}, {response.text}") + + for chunk in response.iter_lines(decode_unicode=False): + chunk = str(chunk.decode("utf-8")) + if chunk == "data: [DONE]": + break + + if chunk and chunk.startswith("data:"): + yield json.loads(chunk[5:].strip("\n")) + + return await asyncio.to_thread(stream_request) + + @override + async def chat( + self, + messages: Sequence[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + images: Optional[Sequence["ImageInput"]] = None, + videos: Optional[Sequence["VideoInput"]] = None, + audios: Optional[Sequence["AudioInput"]] = None, + **input_kwargs, + ) -> list["Response"]: + final_output = None + generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs) + for request_output in generator: + final_output = request_output + + results = [ + Response( + response_text=final_output["text"], + response_length=final_output["meta_info"]["completion_tokens"], + prompt_length=final_output["meta_info"]["prompt_tokens"], + finish_reason="stop" if final_output["meta_info"]["finish_reason"] == "stop" else "length", + ) + ] + return results + + @override + async def stream_chat( + self, + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, + **input_kwargs, + ) -> AsyncGenerator[str, None]: + generated_text = "" + generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs) + for result in generator: + delta_text = result["text"][len(generated_text) :] + generated_text = result["text"] + yield delta_text + + @override + async def get_scores( + self, + batch_input: list[str], + **input_kwargs, + ) -> list[float]: + raise NotImplementedError("SGLang engine does not support `get_scores`.") + + def __del__(self): + r"""Ensure server is cleaned up when object is deleted.""" + self._cleanup_server() + try: + atexit.unregister(self._cleanup_server) + except Exception: + pass diff --git a/LlamaFactory/src/llamafactory/chat/vllm_engine.py b/LlamaFactory/src/llamafactory/chat/vllm_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..075924a2fdb6c12c942b5704bc5ffd49d92808a4 --- /dev/null +++ b/LlamaFactory/src/llamafactory/chat/vllm_engine.py @@ -0,0 +1,271 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import uuid +from collections.abc import AsyncGenerator, AsyncIterator +from typing import TYPE_CHECKING, Any, Optional, Union + +from packaging import version +from typing_extensions import override + +from ..data import get_template_and_fix_tokenizer +from ..extras import logging +from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName +from ..extras.misc import get_device_count +from ..extras.packages import is_vllm_available +from ..model import load_config, load_tokenizer +from ..model.model_utils.quantization import QuantizationMethod +from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM +from .base_engine import BaseEngine, Response + + +if is_vllm_available(): + from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams + from vllm.lora.request import LoRARequest + + +if TYPE_CHECKING: + from ..data.mm_plugin import AudioInput, ImageInput, VideoInput + from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments + + +logger = logging.get_logger(__name__) + + +class VllmEngine(BaseEngine): + def __init__( + self, + model_args: "ModelArguments", + data_args: "DataArguments", + finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", + ) -> None: + self.name = EngineName.VLLM + self.model_args = model_args + config = load_config(model_args) # may download model from ms hub + if getattr(config, "quantization_config", None): # gptq models should use float16 + quantization_config: dict[str, Any] = getattr(config, "quantization_config", None) + quant_method = quantization_config.get("quant_method", "") + if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto": + model_args.infer_dtype = "float16" + + self.can_generate = finetuning_args.stage == "sft" + tokenizer_module = load_tokenizer(model_args) + self.tokenizer = tokenizer_module["tokenizer"] + self.processor = tokenizer_module["processor"] + self.tokenizer.padding_side = "left" + self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args) + self.template.mm_plugin.expand_mm_tokens = False # for vllm generate + self.generating_args = generating_args.to_dict() + + engine_args = { + "model": model_args.model_name_or_path, + "trust_remote_code": model_args.trust_remote_code, + "download_dir": model_args.cache_dir, + "dtype": model_args.infer_dtype, + "max_model_len": model_args.vllm_maxlen, + "tensor_parallel_size": get_device_count() or 1, + "gpu_memory_utilization": model_args.vllm_gpu_util, + "disable_log_stats": True, + "enforce_eager": model_args.vllm_enforce_eager, + "enable_lora": model_args.adapter_name_or_path is not None, + "max_lora_rank": model_args.vllm_max_lora_rank, + } + + import vllm + + if version.parse(vllm.__version__) <= version.parse("0.10.0"): + engine_args["disable_log_requests"] = True + else: + engine_args["enable_log_requests"] = False + + if self.template.mm_plugin.__class__.__name__ != "BasePlugin": + engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2} + + if isinstance(model_args.vllm_config, dict): + engine_args.update(model_args.vllm_config) + + if getattr(config, "is_yi_vl_derived_model", None): + import vllm.model_executor.models.llava + + logger.info_rank0("Detected Yi-VL model, applying projector patch.") + vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM + + self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args)) + if model_args.adapter_name_or_path is not None: + self.lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0]) + else: + self.lora_request = None + + async def _generate( + self, + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, + **input_kwargs, + ) -> AsyncIterator["RequestOutput"]: + request_id = f"chatcmpl-{uuid.uuid4().hex}" + if images is not None and not any(IMAGE_PLACEHOLDER in message["content"] for message in messages): + messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"] + + if videos is not None and not any(VIDEO_PLACEHOLDER in message["content"] for message in messages): + messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"] + + if audios is not None and not any(AUDIO_PLACEHOLDER in message["content"] for message in messages): + messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"] + + messages = self.template.mm_plugin.process_messages( + messages, images or [], videos or [], audios or [], self.processor + ) + paired_messages = messages + [{"role": "assistant", "content": ""}] + prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools) + prompt_length = len(prompt_ids) + + temperature: Optional[float] = input_kwargs.pop("temperature", None) + top_p: Optional[float] = input_kwargs.pop("top_p", None) + top_k: Optional[float] = input_kwargs.pop("top_k", None) + num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1) + repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None) + length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None) + skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None) + max_length: Optional[int] = input_kwargs.pop("max_length", None) + max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None) + stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None) + + if length_penalty is not None: + logger.warning_rank0("Length penalty is not supported by the vllm engine yet.") + + if "max_new_tokens" in self.generating_args: + max_tokens = self.generating_args["max_new_tokens"] + elif "max_length" in self.generating_args: + if self.generating_args["max_length"] > prompt_length: + max_tokens = self.generating_args["max_length"] - prompt_length + else: + max_tokens = 1 + + if max_length: + max_tokens = max_length - prompt_length if max_length > prompt_length else 1 + + if max_new_tokens: + max_tokens = max_new_tokens + + sampling_params = SamplingParams( + n=num_return_sequences, + repetition_penalty=( + repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"] + ) + or 1.0, # repetition_penalty must > 0 + temperature=temperature if temperature is not None else self.generating_args["temperature"], + top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0, # top_p must > 0 + top_k=(top_k if top_k is not None else self.generating_args["top_k"]) or -1, # top_k must > 0 + stop=stop, + stop_token_ids=self.template.get_stop_token_ids(self.tokenizer), + max_tokens=max_tokens, + skip_special_tokens=skip_special_tokens + if skip_special_tokens is not None + else self.generating_args["skip_special_tokens"], + ) + + if images is not None: # add image features + multi_modal_data = { + "image": self.template.mm_plugin._regularize_images( + images, + image_max_pixels=self.model_args.image_max_pixels, + image_min_pixels=self.model_args.image_min_pixels, + )["images"] + } + elif videos is not None: + multi_modal_data = { + "video": self.template.mm_plugin._regularize_videos( + videos, + image_max_pixels=self.model_args.video_max_pixels, + image_min_pixels=self.model_args.video_min_pixels, + video_fps=self.model_args.video_fps, + video_maxlen=self.model_args.video_maxlen, + )["videos"] + } + elif audios is not None: + audio_data = self.template.mm_plugin._regularize_audios( + audios, + sampling_rate=self.model_args.audio_sampling_rate, + ) + multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])} + else: + multi_modal_data = None + + result_generator = self.model.generate( + {"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data}, + sampling_params=sampling_params, + request_id=request_id, + lora_request=self.lora_request, + ) + return result_generator + + @override + async def chat( + self, + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, + **input_kwargs, + ) -> list["Response"]: + final_output = None + generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs) + async for request_output in generator: + final_output = request_output + + results = [] + for output in final_output.outputs: + results.append( + Response( + response_text=output.text, + response_length=len(output.token_ids), + prompt_length=len(final_output.prompt_token_ids), + finish_reason=output.finish_reason, + ) + ) + + return results + + @override + async def stream_chat( + self, + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, + **input_kwargs, + ) -> AsyncGenerator[str, None]: + generated_text = "" + generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs) + async for result in generator: + delta_text = result.outputs[0].text[len(generated_text) :] + generated_text = result.outputs[0].text + yield delta_text + + @override + async def get_scores( + self, + batch_input: list[str], + **input_kwargs, + ) -> list[float]: + raise NotImplementedError("vLLM engine does not support `get_scores`.") diff --git a/LlamaFactory/src/llamafactory/cli.py b/LlamaFactory/src/llamafactory/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..d574bf1db543f5379f074e276898826234708037 --- /dev/null +++ b/LlamaFactory/src/llamafactory/cli.py @@ -0,0 +1,31 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def main(): + from .extras.misc import is_env_enabled + + if is_env_enabled("USE_V1"): + from .v1 import launcher + else: + from . import launcher + + launcher.launch() + + +if __name__ == "__main__": + from multiprocessing import freeze_support + + freeze_support() + main() diff --git a/LlamaFactory/src/llamafactory/data/.ipynb_checkpoints/template-checkpoint.py b/LlamaFactory/src/llamafactory/data/.ipynb_checkpoints/template-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..e83b2e90a60c439370ed5cad4a13846abc977bf5 --- /dev/null +++ b/LlamaFactory/src/llamafactory/data/.ipynb_checkpoints/template-checkpoint.py @@ -0,0 +1,2175 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from copy import deepcopy +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional, Union + +from typing_extensions import override + +from ..extras import logging +from .data_utils import Role +from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter +from .mm_plugin import get_mm_plugin + + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer + + from ..hparams import DataArguments + from .formatter import SLOTS, Formatter + from .mm_plugin import BasePlugin + from .tool_utils import FunctionCall + + +logger = logging.get_logger(__name__) + + +@dataclass +class Template: + format_user: "Formatter" + format_assistant: "Formatter" + format_system: "Formatter" + format_function: "Formatter" + format_observation: "Formatter" + format_tools: "Formatter" + format_prefix: "Formatter" + default_system: str + stop_words: list[str] + thought_words: tuple[str, str] + tool_call_words: tuple[str, str] + efficient_eos: bool + replace_eos: bool + replace_jinja_template: bool + enable_thinking: Optional[bool] + mm_plugin: "BasePlugin" + + def encode_oneturn( + self, + tokenizer: "PreTrainedTokenizer", + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + ) -> tuple[list[int], list[int]]: + r"""Return a single pair of token ids representing prompt and response respectively.""" + encoded_messages = self._encode(tokenizer, messages, system, tools) + prompt_ids = [] + for encoded_ids in encoded_messages[:-1]: + prompt_ids += encoded_ids + + response_ids = encoded_messages[-1] + return prompt_ids, response_ids + + def encode_multiturn( + self, + tokenizer: "PreTrainedTokenizer", + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + ) -> list[tuple[list[int], list[int]]]: + r"""Return multiple pairs of token ids representing prompts and responses respectively.""" + encoded_messages = self._encode(tokenizer, messages, system, tools) + return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)] + + def extract_tool(self, content: str) -> Union[str, list["FunctionCall"]]: + r"""Extract tool message.""" + return self.format_tools.extract(content) + + def get_stop_token_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]: + r"""Return stop token ids.""" + stop_token_ids = {tokenizer.eos_token_id} + for token in self.stop_words: + stop_token_ids.add(tokenizer.convert_tokens_to_ids(token)) + + return list(stop_token_ids) + + def add_thought(self, content: str = "") -> str: + r"""Add empty thought to assistant message.""" + return f"{self.thought_words[0]}{self.thought_words[1]}" + content + + def remove_thought(self, content: str) -> str: + r"""Remove thought from assistant message.""" + pattern = re.compile(f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL) + return re.sub(pattern, "", content).lstrip("\n") + + def get_thought_word_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]: + r"""Get the token ids of thought words.""" + return tokenizer.encode(self.add_thought(), add_special_tokens=False) + + def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> list[int]: + r"""Convert elements to token ids.""" + token_ids = [] + for elem in elements: + if isinstance(elem, str): + if len(elem) != 0: + token_ids += tokenizer.encode(elem, add_special_tokens=False) + elif isinstance(elem, dict): + token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))] + elif isinstance(elem, set): + if "bos_token" in elem and tokenizer.bos_token_id is not None: + token_ids += [tokenizer.bos_token_id] + elif "eos_token" in elem and tokenizer.eos_token_id is not None: + token_ids += [tokenizer.eos_token_id] + else: + raise ValueError(f"Input must be string, set[str] or dict[str, str], got {type(elem)}") + + return token_ids + + def _encode( + self, + tokenizer: "PreTrainedTokenizer", + messages: list[dict[str, str]], + system: Optional[str], + tools: Optional[str], + ) -> list[list[int]]: + r"""Encode formatted inputs to pairs of token ids. + + Turn 0: prefix + system + query resp + Turn t: query resp. + """ + system = system or self.default_system + encoded_messages = [] + for i, message in enumerate(messages): + elements = [] + + if i == 0: + elements += self.format_prefix.apply() + if system or tools: + tool_text = self.format_tools.apply(content=tools)[0] if tools else "" + elements += self.format_system.apply(content=(system + tool_text)) + + if message["role"] == Role.USER: + elements += self.format_user.apply(content=message["content"], idx=str(i // 2)) + elif message["role"] == Role.ASSISTANT: + elements += self.format_assistant.apply(content=message["content"]) + elif message["role"] == Role.OBSERVATION: + elements += self.format_observation.apply(content=message["content"]) + elif message["role"] == Role.FUNCTION: + elements += self.format_function.apply( + content=message["content"], thought_words=self.thought_words, tool_call_words=self.tool_call_words + ) + else: + raise NotImplementedError("Unexpected role: {}".format(message["role"])) + + encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements)) + + return encoded_messages + + @staticmethod + def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None: + r"""Add or replace eos token to the tokenizer.""" + if tokenizer.eos_token == eos_token: + return + + is_added = tokenizer.eos_token_id is None + num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token}) + + if is_added: + logger.info_rank0(f"Add eos token: {tokenizer.eos_token}.") + else: + logger.info_rank0(f"Replace eos token: {tokenizer.eos_token}.") + + if num_added_tokens > 0: + logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.") + + def fix_special_tokens(self, tokenizer: "PreTrainedTokenizer") -> None: + r"""Add eos token and pad token to the tokenizer.""" + stop_words = self.stop_words + if self.replace_eos: + if not stop_words: + raise ValueError("Stop words are required to replace the EOS token.") + + self._add_or_replace_eos_token(tokenizer, eos_token=stop_words[0]) + stop_words = stop_words[1:] + + if tokenizer.eos_token_id is None: + self._add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>") + + if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token + logger.info_rank0(f"Add pad token: {tokenizer.pad_token}") + + if stop_words: + try: + num_added_tokens = tokenizer.add_special_tokens( + dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False + ) + except TypeError: + num_added_tokens = tokenizer.add_special_tokens(dict(additional_special_tokens=stop_words)) + logger.info_rank0("Add {} to stop words.".format(",".join(stop_words))) + if num_added_tokens > 0: + logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.") + + @staticmethod + def _jinja_escape(content: str) -> str: + r"""Escape single quotes in content.""" + return content.replace("'", r"\'") + + @staticmethod + def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str: + r"""Convert slots to jinja template.""" + slot_items = [] + for slot in slots: + if isinstance(slot, str): + slot_pieces = slot.split("{{content}}") + if slot_pieces[0]: + slot_items.append("'" + Template._jinja_escape(slot_pieces[0]) + "'") + if len(slot_pieces) > 1: + slot_items.append(placeholder) + if slot_pieces[1]: + slot_items.append("'" + Template._jinja_escape(slot_pieces[1]) + "'") + elif isinstance(slot, set): # do not use {{ eos_token }} since it may be replaced + if "bos_token" in slot and tokenizer.bos_token_id is not None: + slot_items.append("'" + tokenizer.bos_token + "'") + elif "eos_token" in slot and tokenizer.eos_token_id is not None: + slot_items.append("'" + tokenizer.eos_token + "'") + elif isinstance(slot, dict): + raise ValueError("Dict is not supported.") + + return " + ".join(slot_items) + + def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str: + r"""Return the jinja template.""" + prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer) + system = self._convert_slots_to_jinja(self.format_system.apply(), tokenizer, placeholder="system_message") + user = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer) + assistant = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer) + jinja_template = "" + if prefix: + jinja_template += "{{ " + prefix + " }}" + + if self.default_system: + jinja_template += "{% set system_message = '" + self._jinja_escape(self.default_system) + "' %}" + + jinja_template += ( + "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}" + "{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}" + "{% if system_message is defined %}{{ " + system + " }}{% endif %}" + "{% for message in loop_messages %}" + "{% set content = message['content'] %}" + "{% if message['role'] == 'user' %}" + "{{ " + user + " }}" + "{% elif message['role'] == 'assistant' %}" + "{{ " + assistant + " }}" + "{% endif %}" + "{% endfor %}" + ) + return jinja_template + + def fix_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> None: + r"""Replace the jinja template in the tokenizer.""" + if tokenizer.chat_template is None or self.replace_jinja_template: + try: + tokenizer.chat_template = self._get_jinja_template(tokenizer) + except ValueError as e: + logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.") + + @staticmethod + def _convert_slots_to_ollama( + slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content" + ) -> str: + r"""Convert slots to ollama template.""" + slot_items = [] + for slot in slots: + if isinstance(slot, str): + slot_pieces = slot.split("{{content}}") + if slot_pieces[0]: + slot_items.append(slot_pieces[0]) + if len(slot_pieces) > 1: + slot_items.append("{{ " + placeholder + " }}") + if slot_pieces[1]: + slot_items.append(slot_pieces[1]) + elif isinstance(slot, set): # do not use {{ eos_token }} since it may be replaced + if "bos_token" in slot and tokenizer.bos_token_id is not None: + slot_items.append(tokenizer.bos_token) + elif "eos_token" in slot and tokenizer.eos_token_id is not None: + slot_items.append(tokenizer.eos_token) + elif isinstance(slot, dict): + raise ValueError("Dict is not supported.") + + return "".join(slot_items) + + def _get_ollama_template(self, tokenizer: "PreTrainedTokenizer") -> str: + r"""Return the ollama template.""" + prefix = self._convert_slots_to_ollama(self.format_prefix.apply(), tokenizer) + system = self._convert_slots_to_ollama(self.format_system.apply(), tokenizer, placeholder=".System") + user = self._convert_slots_to_ollama(self.format_user.apply(), tokenizer, placeholder=".Content") + assistant = self._convert_slots_to_ollama(self.format_assistant.apply(), tokenizer, placeholder=".Content") + return ( + f"{prefix}{{{{ if .System }}}}{system}{{{{ end }}}}" + f"""{{{{ range .Messages }}}}{{{{ if eq .Role "user" }}}}{user}""" + f"""{{{{ else if eq .Role "assistant" }}}}{assistant}{{{{ end }}}}{{{{ end }}}}""" + ) + + def get_ollama_modelfile(self, tokenizer: "PreTrainedTokenizer") -> str: + r"""Return the ollama modelfile. + + TODO: support function calling. + """ + modelfile = "# ollama modelfile auto-generated by llamafactory\n\n" + modelfile += f'FROM .\n\nTEMPLATE """{self._get_ollama_template(tokenizer)}"""\n\n' + + if self.default_system: + modelfile += f'SYSTEM """{self.default_system}"""\n\n' + + for stop_token_id in self.get_stop_token_ids(tokenizer): + modelfile += f'PARAMETER stop "{tokenizer.convert_ids_to_tokens(stop_token_id)}"\n' + + modelfile += "PARAMETER num_ctx 4096\n" + return modelfile + + +@dataclass +class Llama2Template(Template): + r"""A template that fuse the system message to first user message.""" + + @override + def _encode( + self, + tokenizer: "PreTrainedTokenizer", + messages: list[dict[str, str]], + system: str, + tools: str, + ) -> list[list[int]]: + system = system or self.default_system + encoded_messages = [] + for i, message in enumerate(messages): + elements = [] + + system_text = "" + if i == 0: + elements += self.format_prefix.apply() + if system or tools: + tool_text = self.format_tools.apply(content=tools)[0] if tools else "" + system_text = self.format_system.apply(content=(system + tool_text))[0] + + if message["role"] == Role.USER: + elements += self.format_user.apply(content=system_text + message["content"]) + elif message["role"] == Role.ASSISTANT: + elements += self.format_assistant.apply(content=message["content"]) + elif message["role"] == Role.OBSERVATION: + elements += self.format_observation.apply(content=message["content"]) + elif message["role"] == Role.FUNCTION: + elements += self.format_function.apply(content=message["content"]) + else: + raise NotImplementedError("Unexpected role: {}".format(message["role"])) + + encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements)) + + return encoded_messages + + def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str: + prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer) + system_message = self._convert_slots_to_jinja( + self.format_system.apply(), tokenizer, placeholder="system_message" + ) + user_message = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer) + assistant_message = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer) + jinja_template = "" + if prefix: + jinja_template += "{{ " + prefix + " }}" + + if self.default_system: + jinja_template += "{% set system_message = '" + self._jinja_escape(self.default_system) + "' %}" + + jinja_template += ( + "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}" + "{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}" + "{% for message in loop_messages %}" + "{% if loop.index0 == 0 and system_message is defined %}" + "{% set content = " + system_message + " + message['content'] %}" + "{% else %}{% set content = message['content'] %}{% endif %}" + "{% if message['role'] == 'user' %}" + "{{ " + user_message + " }}" + "{% elif message['role'] == 'assistant' %}" + "{{ " + assistant_message + " }}" + "{% endif %}" + "{% endfor %}" + ) + return jinja_template + + +@dataclass +class ReasoningTemplate(Template): + r"""A template that add thought to assistant message.""" + + @override + def encode_oneturn( + self, + tokenizer: "PreTrainedTokenizer", + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + ) -> tuple[list[int], list[int]]: + messages = deepcopy(messages) + for i in range(1, len(messages) - 2, 2): + messages[i]["content"] = self.remove_thought(messages[i]["content"]) + + if self.enable_thinking is False: # remove all cot + messages[-1]["content"] = self.remove_thought(messages[-1]["content"]) + + prompt_ids, response_ids = super().encode_oneturn(tokenizer, messages, system, tools) + if ( + self.thought_words[0].strip() not in messages[-1]["content"] + and self.thought_words[1].strip() not in messages[-1]["content"] + ): # add empty cot + if not self.enable_thinking: # do not compute loss + prompt_ids += self.get_thought_word_ids(tokenizer) + else: # do compute loss + response_ids = self.get_thought_word_ids(tokenizer) + response_ids + + return prompt_ids, response_ids + + @override + def encode_multiturn( + self, + tokenizer: "PreTrainedTokenizer", + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + ) -> list[tuple[list[int], list[int]]]: + messages = deepcopy(messages) + if self.enable_thinking is False: # remove all cot + for i in range(1, len(messages), 2): + messages[i]["content"] = self.remove_thought(messages[i]["content"]) + + encoded_messages = self._encode(tokenizer, messages, system, tools) + for i in range(0, len(messages), 2): + if ( + self.thought_words[0].strip() not in messages[i + 1]["content"] + and self.thought_words[1].strip() not in messages[i + 1]["content"] + ): # add empty cot + if not self.enable_thinking: # do not compute loss + encoded_messages[i] += self.get_thought_word_ids(tokenizer) + else: # do compute loss + encoded_messages[i + 1] = self.get_thought_word_ids(tokenizer) + encoded_messages[i + 1] + + return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)] + + +TEMPLATES: dict[str, "Template"] = {} + + +def register_template( + name: str, + format_user: Optional["Formatter"] = None, + format_assistant: Optional["Formatter"] = None, + format_system: Optional["Formatter"] = None, + format_function: Optional["Formatter"] = None, + format_observation: Optional["Formatter"] = None, + format_tools: Optional["Formatter"] = None, + format_prefix: Optional["Formatter"] = None, + default_system: str = "", + stop_words: Optional[list[str]] = None, + thought_words: Optional[tuple[str, str]] = None, + tool_call_words: Optional[tuple[str, str]] = None, + efficient_eos: bool = False, + replace_eos: bool = False, + replace_jinja_template: bool = False, + enable_thinking: Optional[bool] = True, + mm_plugin: "BasePlugin" = get_mm_plugin(name="base"), + template_class: type["Template"] = Template, +) -> None: + r"""Register a chat template. + + To add the following chat template: + ``` + user prompt here + model response here + user prompt here + model response here + ``` + + The corresponding code should be: + ``` + register_template( + name="custom", + format_user=StringFormatter(slots=["{{content}}\n"]), + format_assistant=StringFormatter(slots=["{{content}}\n"]), + format_prefix=EmptyFormatter(""), + ) + ``` + """ + if name in TEMPLATES: + raise ValueError(f"Template {name} already exists.") + + default_slots = ["{{content}}"] if efficient_eos else ["{{content}}", {"eos_token"}] + default_user_formatter = StringFormatter(slots=["{{content}}"]) + default_assistant_formatter = StringFormatter(slots=default_slots) + if format_assistant is not None: + default_function_formatter = FunctionFormatter(slots=format_assistant.slots, tool_format="default") + else: + default_function_formatter = FunctionFormatter(slots=default_slots, tool_format="default") + + default_tool_formatter = ToolFormatter(tool_format="default") + default_prefix_formatter = EmptyFormatter() + TEMPLATES[name] = template_class( + format_user=format_user or default_user_formatter, + format_assistant=format_assistant or default_assistant_formatter, + format_system=format_system or default_user_formatter, + format_function=format_function or default_function_formatter, + format_observation=format_observation or format_user or default_user_formatter, + format_tools=format_tools or default_tool_formatter, + format_prefix=format_prefix or default_prefix_formatter, + default_system=default_system, + stop_words=stop_words or [], + thought_words=thought_words or ("\n", "\n\n\n"), + tool_call_words=tool_call_words or ("", ""), + efficient_eos=efficient_eos, + replace_eos=replace_eos, + replace_jinja_template=replace_jinja_template, + enable_thinking=enable_thinking, + mm_plugin=mm_plugin, + ) + + +def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template": + r"""Extract a chat template from the tokenizer.""" + + def find_diff(short_str: str, long_str: str) -> str: + i, j = 0, 0 + diff = "" + while i < len(short_str) and j < len(long_str): + if short_str[i] == long_str[j]: + i += 1 + j += 1 + else: + diff += long_str[j] + j += 1 + + return diff + + prefix = tokenizer.decode(tokenizer.encode("")) + + messages = [{"role": "system", "content": "{{content}}"}] + system_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)[len(prefix) :] + + messages = [{"role": "system", "content": ""}, {"role": "user", "content": "{{content}}"}] + user_slot_empty_system = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) + user_slot_empty_system = user_slot_empty_system[len(prefix) :] + + messages = [{"role": "user", "content": "{{content}}"}] + user_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) + user_slot = user_slot[len(prefix) :] + + messages = [{"role": "user", "content": "{{content}}"}, {"role": "assistant", "content": "{{content}}"}] + assistant_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False) + assistant_slot = assistant_slot[len(prefix) + len(user_slot) :] + template_class = ReasoningTemplate if "" in assistant_slot else Template + assistant_slot = assistant_slot.replace("", "").replace("", "").lstrip("\n") # remove thought tags + + if len(user_slot) > len(user_slot_empty_system): + default_system = find_diff(user_slot_empty_system, user_slot) + sole_system = system_slot.replace("{{content}}", default_system, 1) + user_slot = user_slot[len(sole_system) :] + else: # if defaut_system is empty, user_slot_empty_system will be longer than user_slot + default_system = "" + + return template_class( + format_user=StringFormatter(slots=[user_slot]), + format_assistant=StringFormatter(slots=[assistant_slot]), + format_system=StringFormatter(slots=[system_slot]), + format_function=FunctionFormatter(slots=[assistant_slot], tool_format="default"), + format_observation=StringFormatter(slots=[user_slot]), + format_tools=ToolFormatter(tool_format="default"), + format_prefix=EmptyFormatter(slots=[prefix]) if prefix else EmptyFormatter(), + default_system=default_system, + stop_words=[], + thought_words=("\n", "\n\n\n"), + tool_call_words=("", ""), + efficient_eos=False, + replace_eos=False, + replace_jinja_template=False, + enable_thinking=True, + mm_plugin=get_mm_plugin(name="base"), + ) + + +def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template": + r"""Get chat template and fixes the tokenizer.""" + if data_args.template is None: + if isinstance(tokenizer.chat_template, str): + logger.warning_rank0("`template` was not specified, try parsing the chat template from the tokenizer.") + template = parse_template(tokenizer) + else: + logger.warning_rank0("`template` was not specified, use `empty` template.") + template = TEMPLATES["empty"] # placeholder + else: + if data_args.template not in TEMPLATES: + raise ValueError(f"Template {data_args.template} does not exist.") + + template = TEMPLATES[data_args.template] + + if data_args.train_on_prompt and template.efficient_eos: + raise ValueError("Current template does not support `train_on_prompt`.") + + if data_args.tool_format is not None: + logger.info_rank0(f"Using tool format: {data_args.tool_format}.") + default_slots = ["{{content}}"] if template.efficient_eos else ["{{content}}", {"eos_token"}] + template.format_function = FunctionFormatter(slots=default_slots, tool_format=data_args.tool_format) + template.format_tools = ToolFormatter(tool_format=data_args.tool_format) + + if data_args.default_system is not None: + logger.info_rank0(f"Using default system message: {data_args.default_system}.") + template.default_system = data_args.default_system + + if isinstance(template, ReasoningTemplate): + logger.warning_rank0( + "You are using reasoning template, " + "please add `_nothink` suffix if the model is not a reasoning model. " + "e.g., qwen3_vl_nothink" + ) + template.enable_thinking = data_args.enable_thinking + + template.fix_special_tokens(tokenizer) + template.fix_jinja_template(tokenizer) + return template + + +register_template( + name="alpaca", + format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]), + format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]), + default_system=( + "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n" + ), + replace_jinja_template=True, +) + + +register_template( + name="bailing", + format_user=StringFormatter(slots=["HUMAN{{content}}ASSISTANT"]), + format_system=StringFormatter(slots=["SYSTEM{{content}}"]), + format_observation=StringFormatter(slots=["OBSERVATION{{content}}ASSISTANT"]), + stop_words=["<|endoftext|>"], + efficient_eos=True, +) + + +register_template( + name="bailing_v2", + format_user=StringFormatter(slots=["HUMAN{{content}}<|role_end|>ASSISTANT"]), + format_system=StringFormatter(slots=["SYSTEM{{content}}<|role_end|>"]), + format_assistant=StringFormatter(slots=["{{content}}<|role_end|>"]), + format_observation=StringFormatter( + slots=[ + "OBSERVATION\n\n{{content}}\n<|role_end|>ASSISTANT" + ] + ), + format_function=FunctionFormatter(slots=["{{content}}<|role_end|>"], tool_format="ling"), + format_tools=ToolFormatter(tool_format="ling"), + stop_words=["<|endoftext|>"], + efficient_eos=True, +) + + +register_template( + name="breeze", + format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + efficient_eos=True, +) + + +register_template( + name="chatglm3", + format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]), + format_assistant=StringFormatter(slots=["\n", "{{content}}"]), + format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]), + format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"), + format_observation=StringFormatter( + slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}] + ), + format_tools=ToolFormatter(tool_format="glm4"), + format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]), + stop_words=["<|user|>", "<|observation|>"], + efficient_eos=True, +) + + +register_template( + name="chatml", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + stop_words=["<|im_end|>", "<|im_start|>"], + replace_eos=True, + replace_jinja_template=True, +) + + +# copied from chatml template +register_template( + name="chatml_de", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.", + stop_words=["<|im_end|>", "<|im_start|>"], + replace_eos=True, + replace_jinja_template=True, +) + + +register_template( + name="cohere", + format_user=StringFormatter( + slots=[ + ( + "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>" + "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" + ) + ] + ), + format_system=StringFormatter(slots=["<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), +) + + +# copied from chatml template +register_template( + name="cpm4", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + stop_words=["<|im_end|>"], +) + + +# copied from chatml template +register_template( + name="dbrx", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + default_system=( + "You are DBRX, created by Databricks. You were last updated in December 2023. " + "You answer questions based on information available up to that point.\n" + "YOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough " + "responses to more complex and open-ended questions.\nYou assist with various tasks, " + "from writing to coding (using markdown for code blocks — remember to use ``` with " + "code, JSON, and tables).\n(You do not have real-time data access or code execution " + "capabilities. You avoid stereotyping and provide balanced perspectives on " + "controversial topics. You do not provide song lyrics, poems, or news articles and " + "do not divulge details of your training data.)\nThis is your system prompt, " + "guiding your responses. Do not reference it, just respond to the user. If you find " + "yourself talking about this message, stop. You should be responding appropriately " + "and usually that means not mentioning this.\nYOU DO NOT MENTION ANY OF THIS INFORMATION " + "ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY." + ), + stop_words=["<|im_end|>"], + replace_eos=True, +) + + +register_template( + name="deepseek", + format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]), + format_system=StringFormatter(slots=["{{content}}\n\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), +) + + +register_template( + name="deepseek3", + format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), +) + + +# copied from deepseek3 template +register_template( + name="deepseekr1", + format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + template_class=ReasoningTemplate, +) + + +register_template( + name="deepseekcoder", + format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]), + format_assistant=StringFormatter(slots=["\n{{content}}\n<|EOT|>\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + default_system=( + "You are an AI programming assistant, utilizing the DeepSeek Coder model, " + "developed by DeepSeek Company, and you only answer questions related to computer science. " + "For politically sensitive questions, security and privacy issues, " + "and other non-computer science questions, you will refuse to answer.\n" + ), +) + + +register_template( + name="default", + format_user=StringFormatter(slots=["Human: {{content}}", {"eos_token"}, "\nAssistant:"]), + format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]), + format_system=StringFormatter(slots=["System: {{content}}", {"eos_token"}, "\n"]), + replace_jinja_template=True, +) + + +register_template( + name="dots_ocr", + format_user=StringFormatter(slots=["<|user|>{{content}}<|endofuser|><|assistant|>"]), + format_assistant=StringFormatter(slots=["{{content}}<|endofassistant|>"]), + format_system=StringFormatter(slots=["<|system|>{{content}}<|endofsystem|>\n"]), + stop_words=["<|endofassistant|>"], + efficient_eos=True, + mm_plugin=get_mm_plugin( + name="qwen2_vl", + image_token="<|imgpad|>", + video_token="<|vidpad|>", + vision_bos_token="<|img|>", + vision_eos_token="<|endofimg|>", + ), +) + + +register_template( + name="empty", + format_assistant=StringFormatter(slots=["{{content}}"]), +) + + +# copied from chatml template +register_template( + name="ernie", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n\n"]), + format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n\n<|im_start|>assistant\n"]), + default_system="\nthink_mode=True\n", + stop_words=["<|im_end|>"], +) + + +register_template( + name="ernie_nothink", + format_user=StringFormatter(slots=["User: {{content}}\nAssistant: "]), + format_assistant=StringFormatter(slots=["{{content}}<|end_of_sentence|>"]), + format_system=StringFormatter(slots=["{{content}}\n"]), + format_prefix=EmptyFormatter(slots=["<|begin_of_sentence|>"]), + stop_words=["<|end_of_sentence|>"], +) + + +register_template( + name="ernie_vl", + format_user=StringFormatter(slots=["User: {{content}}"]), + format_assistant=StringFormatter(slots=["\nAssistant: {{content}}<|end_of_sentence|>"]), + format_system=StringFormatter(slots=["{{content}}\n"]), + stop_words=["<|end_of_sentence|>"], + replace_eos=True, + replace_jinja_template=True, + template_class=ReasoningTemplate, + mm_plugin=get_mm_plugin(name="ernie_vl", image_token="<|IMAGE_PLACEHOLDER|>", video_token="<|VIDEO_PLACEHOLDER|>"), +) + + +register_template( + name="exaone", + format_user=StringFormatter(slots=["[|user|]{{content}}\n[|assistant|]"]), + format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]), + format_system=StringFormatter(slots=["[|system|]{{content}}[|endofturn|]\n"]), +) + + +register_template( + name="falcon", + format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]), + format_assistant=StringFormatter(slots=["{{content}}\n"]), + efficient_eos=True, +) + + +# copied from chatml template +register_template( + name="falcon_h1", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + stop_words=["<|im_end|>", "<|end_of_text|>"], +) + + +register_template( + name="fewshot", + format_assistant=StringFormatter(slots=["{{content}}\n\n"]), + efficient_eos=True, + replace_jinja_template=True, +) + + +register_template( + name="gemma", + format_user=StringFormatter(slots=["user\n{{content}}\nmodel\n"]), + format_assistant=StringFormatter(slots=["{{content}}\n"]), + format_system=StringFormatter(slots=["{{content}}\n\n"]), + format_observation=StringFormatter( + slots=["tool\n{{content}}\nmodel\n"] + ), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + stop_words=[""], + replace_eos=True, + template_class=Llama2Template, +) + + +# copied from gemma template +register_template( + name="gemma2", + format_user=StringFormatter(slots=["user\n{{content}}\nmodel\n"]), + format_assistant=StringFormatter(slots=["{{content}}\n"]), + format_system=StringFormatter(slots=["{{content}}\n\n"]), + format_observation=StringFormatter( + slots=["tool\n{{content}}\nmodel\n"] + ), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + stop_words=["", ""], + efficient_eos=True, + template_class=Llama2Template, +) + + +# copied from gemma template +register_template( + name="gemma3", + format_user=StringFormatter(slots=["user\n{{content}}\nmodel\n"]), + format_assistant=StringFormatter(slots=["{{content}}\n"]), + format_system=StringFormatter(slots=["{{content}}\n\n"]), + format_observation=StringFormatter( + slots=["tool\n{{content}}\nmodel\n"] + ), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + stop_words=[""], + replace_eos=True, + mm_plugin=get_mm_plugin("gemma3", image_token=""), + template_class=Llama2Template, +) + + +register_template( + name="gemma3n", + format_user=StringFormatter(slots=["user\n{{content}}\nmodel\n"]), + format_assistant=StringFormatter(slots=["{{content}}\n"]), + format_system=StringFormatter(slots=["{{content}}\n\n"]), + format_observation=StringFormatter( + slots=["tool\n{{content}}\nmodel\n"] + ), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + stop_words=[""], + replace_eos=True, + mm_plugin=get_mm_plugin("gemma3n", image_token="", audio_token=""), + template_class=Llama2Template, +) + + +register_template( + name="glm4", + format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]), + format_assistant=StringFormatter(slots=["\n{{content}}"]), + format_system=StringFormatter(slots=["<|system|>\n{{content}}"]), + format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"), + format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]), + format_tools=ToolFormatter(tool_format="glm4"), + format_prefix=EmptyFormatter(slots=["[gMASK]"]), + stop_words=["<|user|>", "<|observation|>"], + efficient_eos=True, +) + + +# copied from glm4 template +register_template( + name="glm4_moe", + format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]), + format_assistant=StringFormatter(slots=["\n{{content}}"]), + format_system=StringFormatter(slots=["<|system|>\n{{content}}"]), + format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4_moe"), + format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]), + format_tools=ToolFormatter(tool_format="glm4_moe"), + format_prefix=EmptyFormatter(slots=["[gMASK]"]), + stop_words=["<|user|>", "<|observation|>"], + efficient_eos=True, + template_class=ReasoningTemplate, +) + + +# copied from glm4 template +register_template( + name="glm4v", + format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]), + format_assistant=StringFormatter(slots=["\n{{content}}"]), + format_system=StringFormatter(slots=["<|system|>\n{{content}}"]), + format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"), + format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]), + format_tools=ToolFormatter(tool_format="glm4"), + format_prefix=EmptyFormatter(slots=["[gMASK]"]), + stop_words=["<|user|>", "<|observation|>", ""], + efficient_eos=True, + mm_plugin=get_mm_plugin(name="glm4v", image_token="<|image|>", video_token="<|video|>"), + template_class=ReasoningTemplate, +) + + +# copied from glm4 template +register_template( + name="glm4_5v", + format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]), + format_assistant=StringFormatter(slots=["\n{{content}}"]), + format_system=StringFormatter(slots=["<|system|>\n{{content}}"]), + format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4_moe"), + format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]), + format_tools=ToolFormatter(tool_format="glm4_moe"), + format_prefix=EmptyFormatter(slots=["[gMASK]"]), + stop_words=["<|user|>", "<|observation|>", ""], + efficient_eos=True, + mm_plugin=get_mm_plugin(name="glm4v", image_token="<|image|>", video_token="<|video|>"), + template_class=ReasoningTemplate, +) + + +# copied from glm4 template +register_template( + name="glmz1", + format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]), + format_assistant=StringFormatter(slots=["\n{{content}}"]), + format_system=StringFormatter(slots=["<|system|>\n{{content}}"]), + format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"), + format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]), + format_tools=ToolFormatter(tool_format="glm4"), + format_prefix=EmptyFormatter(slots=["[gMASK]"]), + stop_words=["<|user|>", "<|observation|>"], + efficient_eos=True, + template_class=ReasoningTemplate, +) + + +register_template( + name="gpt_oss", + format_user=StringFormatter(slots=["<|start|>user<|message|>{{content}}<|end|><|start|>assistant"]), + format_assistant=StringFormatter(slots=["{{content}}<|end|>"]), + format_system=StringFormatter(slots=["<|start|>system<|message|>{{content}}<|end|>"]), + default_system="You are ChatGPT, a large language model trained by OpenAI.", + thought_words=("<|channel|>analysis<|message|>", "<|end|><|start|>assistant<|channel|>final<|message|>"), + efficient_eos=True, + template_class=ReasoningTemplate, +) + + +register_template( + name="granite3", + format_user=StringFormatter( + slots=[ + "<|start_of_role|>user<|end_of_role|>{{content}}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>" + ] + ), + format_assistant=StringFormatter(slots=["{{content}}<|end_of_text|>\n"]), + format_system=StringFormatter(slots=["<|start_of_role|>system<|end_of_role|>{{content}}<|end_of_text|>\n"]), +) + + +register_template( + name="granite3_vision", + format_user=StringFormatter(slots=["<|user|>\n{{content}}\n<|assistant|>\n"]), + format_system=StringFormatter(slots=["<|system|>\n{{content}}\n"]), + default_system=( + "A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions." + ), + mm_plugin=get_mm_plugin(name="llava_next", image_token=""), +) + + +register_template( + name="granite4", + format_user=StringFormatter( + slots=[ + "<|start_of_role|>user<|end_of_role|>{{content}}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>" + ] + ), + format_assistant=StringFormatter(slots=["{{content}}<|end_of_text|>\n"]), + format_system=StringFormatter(slots=["<|start_of_role|>system<|end_of_role|>{{content}}<|end_of_text|>\n"]), + format_function=FunctionFormatter(slots=["{{content}}<|end_of_text|>\n"], tool_format="default"), + format_observation=StringFormatter( + slots=["<|start_of_role|>tool<|end_of_role|>{{content}}<|end_of_text|>\n<|start_of_role|>assistant\n"] + ), + format_tools=ToolFormatter(tool_format="default"), + stop_words=["<|end_of_text|>"], + default_system="You are Granite, developed by IBM. You are a helpful AI assistant.", +) + + +register_template( + name="index", + format_user=StringFormatter(slots=["reserved_0{{content}}reserved_1"]), + format_system=StringFormatter(slots=["{{content}}"]), + efficient_eos=True, +) + + +register_template( + name="hunyuan", + format_user=StringFormatter(slots=["{{content}}<|extra_0|>"]), + format_assistant=StringFormatter(slots=["{{content}}<|eos|>"]), + format_system=StringFormatter(slots=["{{content}}<|extra_4|>"]), + format_prefix=EmptyFormatter(slots=["<|startoftext|>"]), + stop_words=["<|eos|>"], +) + + +register_template( + name="hunyuan_small", + format_user=StringFormatter(slots=["<|hy_User|>{{content}}<|hy_place▁holder▁no▁8|>"]), + format_assistant=StringFormatter(slots=["{{content}}<|hy_place▁holder▁no▁2|>"]), + format_system=StringFormatter(slots=["{{content}}<|hy_place▁holder▁no▁3|>"]), + format_prefix=EmptyFormatter(slots=["<|hy_begin▁of▁sentence|>"]), + stop_words=["<|hy_place▁holder▁no▁2|>"], +) + + +register_template( + name="intern2", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + default_system=( + "You are an AI assistant whose name is InternLM (书生·浦语).\n" + "- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory " + "(上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n" + "- InternLM (书生·浦语) can understand and communicate fluently in the language " + "chosen by the user such as English and 中文." + ), + stop_words=["<|im_end|>"], +) + + +register_template( + name="intern_vl", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + default_system=( + "你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。" + ), + stop_words=["<|im_end|>"], + mm_plugin=get_mm_plugin(name="intern_vl", image_token="", video_token=" + ``` + + The corresponding code should be: + ``` + register_template( + name="custom", + format_user=StringFormatter(slots=["{{content}}\n"]), + format_assistant=StringFormatter(slots=["{{content}}\n"]), + format_prefix=EmptyFormatter(""), + ) + ``` + """ + if name in TEMPLATES: + raise ValueError(f"Template {name} already exists.") + + default_slots = ["{{content}}"] if efficient_eos else ["{{content}}", {"eos_token"}] + default_user_formatter = StringFormatter(slots=["{{content}}"]) + default_assistant_formatter = StringFormatter(slots=default_slots) + if format_assistant is not None: + default_function_formatter = FunctionFormatter(slots=format_assistant.slots, tool_format="default") + else: + default_function_formatter = FunctionFormatter(slots=default_slots, tool_format="default") + + default_tool_formatter = ToolFormatter(tool_format="default") + default_prefix_formatter = EmptyFormatter() + TEMPLATES[name] = template_class( + format_user=format_user or default_user_formatter, + format_assistant=format_assistant or default_assistant_formatter, + format_system=format_system or default_user_formatter, + format_function=format_function or default_function_formatter, + format_observation=format_observation or format_user or default_user_formatter, + format_tools=format_tools or default_tool_formatter, + format_prefix=format_prefix or default_prefix_formatter, + default_system=default_system, + stop_words=stop_words or [], + thought_words=thought_words or ("\n", "\n\n\n"), + tool_call_words=tool_call_words or ("", ""), + efficient_eos=efficient_eos, + replace_eos=replace_eos, + replace_jinja_template=replace_jinja_template, + enable_thinking=enable_thinking, + mm_plugin=mm_plugin, + ) + + +def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template": + r"""Extract a chat template from the tokenizer.""" + + def find_diff(short_str: str, long_str: str) -> str: + i, j = 0, 0 + diff = "" + while i < len(short_str) and j < len(long_str): + if short_str[i] == long_str[j]: + i += 1 + j += 1 + else: + diff += long_str[j] + j += 1 + + return diff + + prefix = tokenizer.decode(tokenizer.encode("")) + + messages = [{"role": "system", "content": "{{content}}"}] + system_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)[len(prefix) :] + + messages = [{"role": "system", "content": ""}, {"role": "user", "content": "{{content}}"}] + user_slot_empty_system = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) + user_slot_empty_system = user_slot_empty_system[len(prefix) :] + + messages = [{"role": "user", "content": "{{content}}"}] + user_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) + user_slot = user_slot[len(prefix) :] + + messages = [{"role": "user", "content": "{{content}}"}, {"role": "assistant", "content": "{{content}}"}] + assistant_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False) + assistant_slot = assistant_slot[len(prefix) + len(user_slot) :] + template_class = ReasoningTemplate if "" in assistant_slot else Template + assistant_slot = assistant_slot.replace("", "").replace("", "").lstrip("\n") # remove thought tags + + if len(user_slot) > len(user_slot_empty_system): + default_system = find_diff(user_slot_empty_system, user_slot) + sole_system = system_slot.replace("{{content}}", default_system, 1) + user_slot = user_slot[len(sole_system) :] + else: # if defaut_system is empty, user_slot_empty_system will be longer than user_slot + default_system = "" + + return template_class( + format_user=StringFormatter(slots=[user_slot]), + format_assistant=StringFormatter(slots=[assistant_slot]), + format_system=StringFormatter(slots=[system_slot]), + format_function=FunctionFormatter(slots=[assistant_slot], tool_format="default"), + format_observation=StringFormatter(slots=[user_slot]), + format_tools=ToolFormatter(tool_format="default"), + format_prefix=EmptyFormatter(slots=[prefix]) if prefix else EmptyFormatter(), + default_system=default_system, + stop_words=[], + thought_words=("\n", "\n\n\n"), + tool_call_words=("", ""), + efficient_eos=False, + replace_eos=False, + replace_jinja_template=False, + enable_thinking=True, + mm_plugin=get_mm_plugin(name="base"), + ) + + +def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template": + r"""Get chat template and fixes the tokenizer.""" + if data_args.template is None: + if isinstance(tokenizer.chat_template, str): + logger.warning_rank0("`template` was not specified, try parsing the chat template from the tokenizer.") + template = parse_template(tokenizer) + else: + logger.warning_rank0("`template` was not specified, use `empty` template.") + template = TEMPLATES["empty"] # placeholder + else: + if data_args.template not in TEMPLATES: + raise ValueError(f"Template {data_args.template} does not exist.") + + template = TEMPLATES[data_args.template] + + if data_args.train_on_prompt and template.efficient_eos: + raise ValueError("Current template does not support `train_on_prompt`.") + + if data_args.tool_format is not None: + logger.info_rank0(f"Using tool format: {data_args.tool_format}.") + default_slots = ["{{content}}"] if template.efficient_eos else ["{{content}}", {"eos_token"}] + template.format_function = FunctionFormatter(slots=default_slots, tool_format=data_args.tool_format) + template.format_tools = ToolFormatter(tool_format=data_args.tool_format) + + if data_args.default_system is not None: + logger.info_rank0(f"Using default system message: {data_args.default_system}.") + template.default_system = data_args.default_system + + if isinstance(template, ReasoningTemplate): + logger.warning_rank0( + "You are using reasoning template, " + "please add `_nothink` suffix if the model is not a reasoning model. " + "e.g., qwen3_vl_nothink" + ) + template.enable_thinking = data_args.enable_thinking + + template.fix_special_tokens(tokenizer) + template.fix_jinja_template(tokenizer) + return template + + +register_template( + name="alpaca", + format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]), + format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]), + default_system=( + "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n" + ), + replace_jinja_template=True, +) + + +register_template( + name="bailing", + format_user=StringFormatter(slots=["HUMAN{{content}}ASSISTANT"]), + format_system=StringFormatter(slots=["SYSTEM{{content}}"]), + format_observation=StringFormatter(slots=["OBSERVATION{{content}}ASSISTANT"]), + stop_words=["<|endoftext|>"], + efficient_eos=True, +) + + +register_template( + name="bailing_v2", + format_user=StringFormatter(slots=["HUMAN{{content}}<|role_end|>ASSISTANT"]), + format_system=StringFormatter(slots=["SYSTEM{{content}}<|role_end|>"]), + format_assistant=StringFormatter(slots=["{{content}}<|role_end|>"]), + format_observation=StringFormatter( + slots=[ + "OBSERVATION\n\n{{content}}\n<|role_end|>ASSISTANT" + ] + ), + format_function=FunctionFormatter(slots=["{{content}}<|role_end|>"], tool_format="ling"), + format_tools=ToolFormatter(tool_format="ling"), + stop_words=["<|endoftext|>"], + efficient_eos=True, +) + + +register_template( + name="breeze", + format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + efficient_eos=True, +) + + +register_template( + name="chatglm3", + format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]), + format_assistant=StringFormatter(slots=["\n", "{{content}}"]), + format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]), + format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"), + format_observation=StringFormatter( + slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}] + ), + format_tools=ToolFormatter(tool_format="glm4"), + format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]), + stop_words=["<|user|>", "<|observation|>"], + efficient_eos=True, +) + + +register_template( + name="chatml", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + stop_words=["<|im_end|>", "<|im_start|>"], + replace_eos=True, + replace_jinja_template=True, +) + + +# copied from chatml template +register_template( + name="chatml_de", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.", + stop_words=["<|im_end|>", "<|im_start|>"], + replace_eos=True, + replace_jinja_template=True, +) + + +register_template( + name="cohere", + format_user=StringFormatter( + slots=[ + ( + "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>" + "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" + ) + ] + ), + format_system=StringFormatter(slots=["<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), +) + + +# copied from chatml template +register_template( + name="cpm4", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + stop_words=["<|im_end|>"], +) + + +# copied from chatml template +register_template( + name="dbrx", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + default_system=( + "You are DBRX, created by Databricks. You were last updated in December 2023. " + "You answer questions based on information available up to that point.\n" + "YOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough " + "responses to more complex and open-ended questions.\nYou assist with various tasks, " + "from writing to coding (using markdown for code blocks — remember to use ``` with " + "code, JSON, and tables).\n(You do not have real-time data access or code execution " + "capabilities. You avoid stereotyping and provide balanced perspectives on " + "controversial topics. You do not provide song lyrics, poems, or news articles and " + "do not divulge details of your training data.)\nThis is your system prompt, " + "guiding your responses. Do not reference it, just respond to the user. If you find " + "yourself talking about this message, stop. You should be responding appropriately " + "and usually that means not mentioning this.\nYOU DO NOT MENTION ANY OF THIS INFORMATION " + "ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY." + ), + stop_words=["<|im_end|>"], + replace_eos=True, +) + + +register_template( + name="deepseek", + format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]), + format_system=StringFormatter(slots=["{{content}}\n\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), +) + + +register_template( + name="deepseek3", + format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), +) + + +# copied from deepseek3 template +register_template( + name="deepseekr1", + format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + template_class=ReasoningTemplate, +) + + +register_template( + name="deepseekcoder", + format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]), + format_assistant=StringFormatter(slots=["\n{{content}}\n<|EOT|>\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + default_system=( + "You are an AI programming assistant, utilizing the DeepSeek Coder model, " + "developed by DeepSeek Company, and you only answer questions related to computer science. " + "For politically sensitive questions, security and privacy issues, " + "and other non-computer science questions, you will refuse to answer.\n" + ), +) + + +register_template( + name="default", + format_user=StringFormatter(slots=["Human: {{content}}", {"eos_token"}, "\nAssistant:"]), + format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]), + format_system=StringFormatter(slots=["System: {{content}}", {"eos_token"}, "\n"]), + replace_jinja_template=True, +) + + +register_template( + name="dots_ocr", + format_user=StringFormatter(slots=["<|user|>{{content}}<|endofuser|><|assistant|>"]), + format_assistant=StringFormatter(slots=["{{content}}<|endofassistant|>"]), + format_system=StringFormatter(slots=["<|system|>{{content}}<|endofsystem|>\n"]), + stop_words=["<|endofassistant|>"], + efficient_eos=True, + mm_plugin=get_mm_plugin( + name="qwen2_vl", + image_token="<|imgpad|>", + video_token="<|vidpad|>", + vision_bos_token="<|img|>", + vision_eos_token="<|endofimg|>", + ), +) + + +register_template( + name="empty", + format_assistant=StringFormatter(slots=["{{content}}"]), +) + + +# copied from chatml template +register_template( + name="ernie", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n\n"]), + format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n\n<|im_start|>assistant\n"]), + default_system="\nthink_mode=True\n", + stop_words=["<|im_end|>"], +) + + +register_template( + name="ernie_nothink", + format_user=StringFormatter(slots=["User: {{content}}\nAssistant: "]), + format_assistant=StringFormatter(slots=["{{content}}<|end_of_sentence|>"]), + format_system=StringFormatter(slots=["{{content}}\n"]), + format_prefix=EmptyFormatter(slots=["<|begin_of_sentence|>"]), + stop_words=["<|end_of_sentence|>"], +) + + +register_template( + name="ernie_vl", + format_user=StringFormatter(slots=["User: {{content}}"]), + format_assistant=StringFormatter(slots=["\nAssistant: {{content}}<|end_of_sentence|>"]), + format_system=StringFormatter(slots=["{{content}}\n"]), + stop_words=["<|end_of_sentence|>"], + replace_eos=True, + replace_jinja_template=True, + template_class=ReasoningTemplate, + mm_plugin=get_mm_plugin(name="ernie_vl", image_token="<|IMAGE_PLACEHOLDER|>", video_token="<|VIDEO_PLACEHOLDER|>"), +) + + +register_template( + name="exaone", + format_user=StringFormatter(slots=["[|user|]{{content}}\n[|assistant|]"]), + format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]), + format_system=StringFormatter(slots=["[|system|]{{content}}[|endofturn|]\n"]), +) + + +register_template( + name="falcon", + format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]), + format_assistant=StringFormatter(slots=["{{content}}\n"]), + efficient_eos=True, +) + + +# copied from chatml template +register_template( + name="falcon_h1", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + stop_words=["<|im_end|>", "<|end_of_text|>"], +) + + +register_template( + name="fewshot", + format_assistant=StringFormatter(slots=["{{content}}\n\n"]), + efficient_eos=True, + replace_jinja_template=True, +) + + +register_template( + name="gemma", + format_user=StringFormatter(slots=["user\n{{content}}\nmodel\n"]), + format_assistant=StringFormatter(slots=["{{content}}\n"]), + format_system=StringFormatter(slots=["{{content}}\n\n"]), + format_observation=StringFormatter( + slots=["tool\n{{content}}\nmodel\n"] + ), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + stop_words=[""], + replace_eos=True, + template_class=Llama2Template, +) + + +# copied from gemma template +register_template( + name="gemma2", + format_user=StringFormatter(slots=["user\n{{content}}\nmodel\n"]), + format_assistant=StringFormatter(slots=["{{content}}\n"]), + format_system=StringFormatter(slots=["{{content}}\n\n"]), + format_observation=StringFormatter( + slots=["tool\n{{content}}\nmodel\n"] + ), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + stop_words=["", ""], + efficient_eos=True, + template_class=Llama2Template, +) + + +# copied from gemma template +register_template( + name="gemma3", + format_user=StringFormatter(slots=["user\n{{content}}\nmodel\n"]), + format_assistant=StringFormatter(slots=["{{content}}\n"]), + format_system=StringFormatter(slots=["{{content}}\n\n"]), + format_observation=StringFormatter( + slots=["tool\n{{content}}\nmodel\n"] + ), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + stop_words=[""], + replace_eos=True, + mm_plugin=get_mm_plugin("gemma3", image_token=""), + template_class=Llama2Template, +) + + +register_template( + name="gemma3n", + format_user=StringFormatter(slots=["user\n{{content}}\nmodel\n"]), + format_assistant=StringFormatter(slots=["{{content}}\n"]), + format_system=StringFormatter(slots=["{{content}}\n\n"]), + format_observation=StringFormatter( + slots=["tool\n{{content}}\nmodel\n"] + ), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + stop_words=[""], + replace_eos=True, + mm_plugin=get_mm_plugin("gemma3n", image_token="", audio_token=""), + template_class=Llama2Template, +) + + +register_template( + name="glm4", + format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]), + format_assistant=StringFormatter(slots=["\n{{content}}"]), + format_system=StringFormatter(slots=["<|system|>\n{{content}}"]), + format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"), + format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]), + format_tools=ToolFormatter(tool_format="glm4"), + format_prefix=EmptyFormatter(slots=["[gMASK]"]), + stop_words=["<|user|>", "<|observation|>"], + efficient_eos=True, +) + + +# copied from glm4 template +register_template( + name="glm4_moe", + format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]), + format_assistant=StringFormatter(slots=["\n{{content}}"]), + format_system=StringFormatter(slots=["<|system|>\n{{content}}"]), + format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4_moe"), + format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]), + format_tools=ToolFormatter(tool_format="glm4_moe"), + format_prefix=EmptyFormatter(slots=["[gMASK]"]), + stop_words=["<|user|>", "<|observation|>"], + efficient_eos=True, + template_class=ReasoningTemplate, +) + + +# copied from glm4 template +register_template( + name="glm4v", + format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]), + format_assistant=StringFormatter(slots=["\n{{content}}"]), + format_system=StringFormatter(slots=["<|system|>\n{{content}}"]), + format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"), + format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]), + format_tools=ToolFormatter(tool_format="glm4"), + format_prefix=EmptyFormatter(slots=["[gMASK]"]), + stop_words=["<|user|>", "<|observation|>", ""], + efficient_eos=True, + mm_plugin=get_mm_plugin(name="glm4v", image_token="<|image|>", video_token="<|video|>"), + template_class=ReasoningTemplate, +) + + +# copied from glm4 template +register_template( + name="glm4_5v", + format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]), + format_assistant=StringFormatter(slots=["\n{{content}}"]), + format_system=StringFormatter(slots=["<|system|>\n{{content}}"]), + format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4_moe"), + format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]), + format_tools=ToolFormatter(tool_format="glm4_moe"), + format_prefix=EmptyFormatter(slots=["[gMASK]"]), + stop_words=["<|user|>", "<|observation|>", ""], + efficient_eos=True, + mm_plugin=get_mm_plugin(name="glm4v", image_token="<|image|>", video_token="<|video|>"), + template_class=ReasoningTemplate, +) + + +# copied from glm4 template +register_template( + name="glmz1", + format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]), + format_assistant=StringFormatter(slots=["\n{{content}}"]), + format_system=StringFormatter(slots=["<|system|>\n{{content}}"]), + format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"), + format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]), + format_tools=ToolFormatter(tool_format="glm4"), + format_prefix=EmptyFormatter(slots=["[gMASK]"]), + stop_words=["<|user|>", "<|observation|>"], + efficient_eos=True, + template_class=ReasoningTemplate, +) + + +register_template( + name="gpt_oss", + format_user=StringFormatter(slots=["<|start|>user<|message|>{{content}}<|end|><|start|>assistant"]), + format_assistant=StringFormatter(slots=["{{content}}<|end|>"]), + format_system=StringFormatter(slots=["<|start|>system<|message|>{{content}}<|end|>"]), + default_system="You are ChatGPT, a large language model trained by OpenAI.", + thought_words=("<|channel|>analysis<|message|>", "<|end|><|start|>assistant<|channel|>final<|message|>"), + efficient_eos=True, + template_class=ReasoningTemplate, +) + + +register_template( + name="granite3", + format_user=StringFormatter( + slots=[ + "<|start_of_role|>user<|end_of_role|>{{content}}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>" + ] + ), + format_assistant=StringFormatter(slots=["{{content}}<|end_of_text|>\n"]), + format_system=StringFormatter(slots=["<|start_of_role|>system<|end_of_role|>{{content}}<|end_of_text|>\n"]), +) + + +register_template( + name="granite3_vision", + format_user=StringFormatter(slots=["<|user|>\n{{content}}\n<|assistant|>\n"]), + format_system=StringFormatter(slots=["<|system|>\n{{content}}\n"]), + default_system=( + "A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions." + ), + mm_plugin=get_mm_plugin(name="llava_next", image_token=""), +) + + +register_template( + name="granite4", + format_user=StringFormatter( + slots=[ + "<|start_of_role|>user<|end_of_role|>{{content}}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>" + ] + ), + format_assistant=StringFormatter(slots=["{{content}}<|end_of_text|>\n"]), + format_system=StringFormatter(slots=["<|start_of_role|>system<|end_of_role|>{{content}}<|end_of_text|>\n"]), + format_function=FunctionFormatter(slots=["{{content}}<|end_of_text|>\n"], tool_format="default"), + format_observation=StringFormatter( + slots=["<|start_of_role|>tool<|end_of_role|>{{content}}<|end_of_text|>\n<|start_of_role|>assistant\n"] + ), + format_tools=ToolFormatter(tool_format="default"), + stop_words=["<|end_of_text|>"], + default_system="You are Granite, developed by IBM. You are a helpful AI assistant.", +) + + +register_template( + name="index", + format_user=StringFormatter(slots=["reserved_0{{content}}reserved_1"]), + format_system=StringFormatter(slots=["{{content}}"]), + efficient_eos=True, +) + + +register_template( + name="hunyuan", + format_user=StringFormatter(slots=["{{content}}<|extra_0|>"]), + format_assistant=StringFormatter(slots=["{{content}}<|eos|>"]), + format_system=StringFormatter(slots=["{{content}}<|extra_4|>"]), + format_prefix=EmptyFormatter(slots=["<|startoftext|>"]), + stop_words=["<|eos|>"], +) + + +register_template( + name="hunyuan_small", + format_user=StringFormatter(slots=["<|hy_User|>{{content}}<|hy_place▁holder▁no▁8|>"]), + format_assistant=StringFormatter(slots=["{{content}}<|hy_place▁holder▁no▁2|>"]), + format_system=StringFormatter(slots=["{{content}}<|hy_place▁holder▁no▁3|>"]), + format_prefix=EmptyFormatter(slots=["<|hy_begin▁of▁sentence|>"]), + stop_words=["<|hy_place▁holder▁no▁2|>"], +) + + +register_template( + name="intern2", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + default_system=( + "You are an AI assistant whose name is InternLM (书生·浦语).\n" + "- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory " + "(上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n" + "- InternLM (书生·浦语) can understand and communicate fluently in the language " + "chosen by the user such as English and 中文." + ), + stop_words=["<|im_end|>"], +) + + +register_template( + name="intern_vl", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + default_system=( + "你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。" + ), + stop_words=["<|im_end|>"], + mm_plugin=get_mm_plugin(name="intern_vl", image_token="", video_token=""], tool_format="default") + tool_calls = json.dumps([FUNCTION] * 2) + assert formatter.apply(content=tool_calls) == [ + """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""" + """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}""", + "", + ] + + +@pytest.mark.runs_on(["cpu", "mps"]) +def test_default_tool_formatter(): + formatter = ToolFormatter(tool_format="default") + assert formatter.apply(content=json.dumps(TOOLS)) == [ + "You have access to the following tools:\n" + "> Tool Name: test_tool\n" + "Tool Description: tool_desc\n" + "Tool Args:\n" + " - foo (string, required): foo_desc\n" + " - bar (number): bar_desc\n\n" + "Use the following format if using a tool:\n" + "```\n" + "Action: tool name (one of [test_tool])\n" + "Action Input: the input to the tool, in a JSON format representing the kwargs " + """(e.g. ```{"input": "hello world", "num_beams": 5}```)\n""" + "```\n" + ] + + +@pytest.mark.runs_on(["cpu", "mps"]) +def test_default_tool_extractor(): + formatter = ToolFormatter(tool_format="default") + result = """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}""" + assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")] + + +@pytest.mark.runs_on(["cpu", "mps"]) +def test_default_multi_tool_extractor(): + formatter = ToolFormatter(tool_format="default") + result = ( + """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}\n""" + """Action: another_tool\nAction Input: {"foo": "job", "size": 2}""" + ) + assert formatter.extract(result) == [ + ("test_tool", """{"foo": "bar", "size": 10}"""), + ("another_tool", """{"foo": "job", "size": 2}"""), + ] + + +@pytest.mark.runs_on(["cpu", "mps"]) +def test_glm4_function_formatter(): + formatter = FunctionFormatter(slots=["{{content}}"], tool_format="glm4") + tool_calls = json.dumps(FUNCTION) + assert formatter.apply(content=tool_calls) == ["""tool_name\n{"foo": "bar", "size": 10}"""] + + +@pytest.mark.runs_on(["cpu", "mps"]) +def test_glm4_tool_formatter(): + formatter = ToolFormatter(tool_format="glm4") + assert formatter.apply(content=json.dumps(TOOLS)) == [ + "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱 AI 公司训练的语言模型 GLM-4 模型开发的," + "你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具\n\n" + f"## test_tool\n\n{json.dumps(TOOLS[0], indent=4, ensure_ascii=False)}\n" + "在调用上述函数时,请使用 Json 格式表示调用的参数。" + ] + + +@pytest.mark.runs_on(["cpu", "mps"]) +def test_glm4_tool_extractor(): + formatter = ToolFormatter(tool_format="glm4") + result = """test_tool\n{"foo": "bar", "size": 10}\n""" + assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")] + + +@pytest.mark.runs_on(["cpu", "mps"]) +def test_llama3_function_formatter(): + formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3") + tool_calls = json.dumps(FUNCTION) + assert formatter.apply(content=tool_calls) == [ + """{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}<|eot_id|>""" + ] + + +@pytest.mark.runs_on(["cpu", "mps"]) +def test_llama3_multi_function_formatter(): + formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3") + tool_calls = json.dumps([FUNCTION] * 2) + assert formatter.apply(content=tool_calls) == [ + """[{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}, """ + """{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}]""" + """<|eot_id|>""" + ] + + +@pytest.mark.runs_on(["cpu", "mps"]) +def test_llama3_tool_formatter(): + formatter = ToolFormatter(tool_format="llama3") + date = datetime.now().strftime("%d %b %Y") + wrapped_tool = {"type": "function", "function": TOOLS[0]} + assert formatter.apply(content=json.dumps(TOOLS)) == [ + f"Cutting Knowledge Date: December 2023\nToday Date: {date}\n\n" + "You have access to the following functions. " + "To call a function, please respond with JSON for a function call. " + """Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. """ + f"Do not use variables.\n\n{json.dumps(wrapped_tool, indent=4, ensure_ascii=False)}\n\n" + ] + + +@pytest.mark.runs_on(["cpu", "mps"]) +def test_llama3_tool_extractor(): + formatter = ToolFormatter(tool_format="llama3") + result = """{"name": "test_tool", "parameters": {"foo": "bar", "size": 10}}\n""" + assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")] + + +@pytest.mark.runs_on(["cpu", "mps"]) +def test_llama3_multi_tool_extractor(): + formatter = ToolFormatter(tool_format="llama3") + result = ( + """[{"name": "test_tool", "parameters": {"foo": "bar", "size": 10}}, """ + """{"name": "another_tool", "parameters": {"foo": "job", "size": 2}}]""" + ) + assert formatter.extract(result) == [ + ("test_tool", """{"foo": "bar", "size": 10}"""), + ("another_tool", """{"foo": "job", "size": 2}"""), + ] + + +@pytest.mark.runs_on(["cpu", "mps"]) +def test_mistral_function_formatter(): + formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", ""], tool_format="mistral") + tool_calls = json.dumps(FUNCTION) + assert formatter.apply(content=tool_calls) == [ + "[TOOL_CALLS] " """[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]""", + "", + ] + + +@pytest.mark.runs_on(["cpu", "mps"]) +def test_mistral_multi_function_formatter(): + formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", ""], tool_format="mistral") + tool_calls = json.dumps([FUNCTION] * 2) + assert formatter.apply(content=tool_calls) == [ + "[TOOL_CALLS] " + """[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}, """ + """{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]""", + "", + ] + + +@pytest.mark.runs_on(["cpu", "mps"]) +def test_mistral_tool_formatter(): + formatter = ToolFormatter(tool_format="mistral") + wrapped_tool = {"type": "function", "function": TOOLS[0]} + assert formatter.apply(content=json.dumps(TOOLS)) == [ + "[AVAILABLE_TOOLS] " + json.dumps([wrapped_tool], ensure_ascii=False) + "[/AVAILABLE_TOOLS]" + ] + + +@pytest.mark.runs_on(["cpu", "mps"]) +def test_mistral_tool_extractor(): + formatter = ToolFormatter(tool_format="mistral") + result = """{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}""" + assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")] + + +@pytest.mark.runs_on(["cpu", "mps"]) +def test_mistral_multi_tool_extractor(): + formatter = ToolFormatter(tool_format="mistral") + result = ( + """[{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}, """ + """{"name": "another_tool", "arguments": {"foo": "job", "size": 2}}]""" + ) + assert formatter.extract(result) == [ + ("test_tool", """{"foo": "bar", "size": 10}"""), + ("another_tool", """{"foo": "job", "size": 2}"""), + ] + + +@pytest.mark.runs_on(["cpu", "mps"]) +def test_qwen_function_formatter(): + formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen") + tool_calls = json.dumps(FUNCTION) + assert formatter.apply(content=tool_calls) == [ + """\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n<|im_end|>\n""" + ] + + +@pytest.mark.runs_on(["cpu", "mps"]) +def test_qwen_multi_function_formatter(): + formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen") + tool_calls = json.dumps([FUNCTION] * 2) + assert formatter.apply(content=tool_calls) == [ + """\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n\n""" + """\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n""" + "<|im_end|>\n" + ] + + +@pytest.mark.runs_on(["cpu", "mps"]) +def test_qwen_tool_formatter(): + formatter = ToolFormatter(tool_format="qwen") + wrapped_tool = {"type": "function", "function": TOOLS[0]} + assert formatter.apply(content=json.dumps(TOOLS)) == [ + "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\n" + "You are provided with function signatures within XML tags:\n" + f"\n{json.dumps(wrapped_tool, ensure_ascii=False)}" + "\n\n\nFor each function call, return a json object with function name and arguments within " + """ XML tags:\n\n{"name": , """ + """"arguments": }\n""" + ] + + +@pytest.mark.runs_on(["cpu", "mps"]) +def test_qwen_tool_extractor(): + formatter = ToolFormatter(tool_format="qwen") + result = """\n{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}\n""" + assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")] + + +@pytest.mark.runs_on(["cpu", "mps"]) +def test_qwen_multi_tool_extractor(): + formatter = ToolFormatter(tool_format="qwen") + result = ( + """\n{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}\n\n""" + """\n{"name": "another_tool", "arguments": {"foo": "job", "size": 2}}\n""" + ) + assert formatter.extract(result) == [ + ("test_tool", """{"foo": "bar", "size": 10}"""), + ("another_tool", """{"foo": "job", "size": 2}"""), + ] + + +@pytest.mark.runs_on(["cpu", "mps"]) +def test_lfm2_function_formatter(): + formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="lfm2") + tool_calls = json.dumps(FUNCTION) + assert formatter.apply(content=tool_calls) == [ + """<|tool_call_start|>[tool_name(foo="bar", size=10)]<|tool_call_end|><|im_end|>\n""" + ] + + +@pytest.mark.runs_on(["cpu", "mps"]) +def test_lfm2_multi_function_formatter(): + formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="lfm2") + tool_calls = json.dumps([FUNCTION] * 2) + assert formatter.apply(content=tool_calls) == [ + """<|tool_call_start|>[tool_name(foo="bar", size=10), tool_name(foo="bar", size=10)]<|tool_call_end|>""" + "<|im_end|>\n" + ] + + +@pytest.mark.runs_on(["cpu", "mps"]) +def test_lfm2_tool_formatter(): + formatter = ToolFormatter(tool_format="lfm2") + assert formatter.apply(content=json.dumps(TOOLS)) == [ + "List of tools: <|tool_list_start|>" + json.dumps(TOOLS, ensure_ascii=False) + "<|tool_list_end|>" + ] + + +@pytest.mark.runs_on(["cpu", "mps"]) +def test_lfm2_tool_extractor(): + formatter = ToolFormatter(tool_format="lfm2") + result = """<|tool_call_start|>[test_tool(foo="bar", size=10)]<|tool_call_end|>""" + assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")] + + +@pytest.mark.runs_on(["cpu", "mps"]) +def test_lfm2_multi_tool_extractor(): + formatter = ToolFormatter(tool_format="lfm2") + result = """<|tool_call_start|>[test_tool(foo="bar", size=10), another_tool(foo="job", size=2)]<|tool_call_end|>""" + assert formatter.extract(result) == [ + ("test_tool", """{"foo": "bar", "size": 10}"""), + ("another_tool", """{"foo": "job", "size": 2}"""), + ] + + +@pytest.mark.runs_on(["cpu", "mps"]) +def test_lfm2_tool_extractor_with_nested_dict(): + formatter = ToolFormatter(tool_format="lfm2") + result = """<|tool_call_start|>[search(query="test", options={"limit": 10, "offset": 0})]<|tool_call_end|>""" + extracted = formatter.extract(result) + assert len(extracted) == 1 + assert extracted[0][0] == "search" + args = json.loads(extracted[0][1]) + assert args["query"] == "test" + assert args["options"] == {"limit": 10, "offset": 0} + + +@pytest.mark.runs_on(["cpu", "mps"]) +def test_lfm2_tool_extractor_with_list_arg(): + formatter = ToolFormatter(tool_format="lfm2") + result = """<|tool_call_start|>[batch_process(items=[1, 2, 3], enabled=True)]<|tool_call_end|>""" + extracted = formatter.extract(result) + assert len(extracted) == 1 + assert extracted[0][0] == "batch_process" + args = json.loads(extracted[0][1]) + assert args["items"] == [1, 2, 3] + assert args["enabled"] is True + + +@pytest.mark.runs_on(["cpu", "mps"]) +def test_lfm2_tool_extractor_no_match(): + formatter = ToolFormatter(tool_format="lfm2") + result = "This is a regular response without tool calls." + extracted = formatter.extract(result) + assert extracted == result + + +@pytest.mark.runs_on(["cpu", "mps"]) +def test_lfm2_tool_round_trip(): + formatter = FunctionFormatter(slots=["{{content}}"], tool_format="lfm2") + tool_formatter = ToolFormatter(tool_format="lfm2") + original = {"name": "my_func", "arguments": {"arg1": "hello", "arg2": 42, "arg3": True}} + formatted = formatter.apply(content=json.dumps(original)) + extracted = tool_formatter.extract(formatted[0]) + assert len(extracted) == 1 + assert extracted[0][0] == original["name"] + assert json.loads(extracted[0][1]) == original["arguments"] diff --git a/LlamaFactory/tests/data/test_loader.py b/LlamaFactory/tests/data/test_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..907bda347d18807862f59c5c38d2f162cb035d12 --- /dev/null +++ b/LlamaFactory/tests/data/test_loader.py @@ -0,0 +1,61 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import pytest + +from llamafactory.train.test_utils import load_dataset_module + + +DEMO_DATA = os.getenv("DEMO_DATA", "llamafactory/demo_data") + +TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3") + +TINY_DATA = os.getenv("TINY_DATA", "llamafactory/tiny-supervised-dataset") + +TRAIN_ARGS = { + "model_name_or_path": TINY_LLAMA3, + "stage": "sft", + "do_train": True, + "finetuning_type": "full", + "template": "llama3", + "dataset": TINY_DATA, + "dataset_dir": "ONLINE", + "cutoff_len": 8192, + "output_dir": "dummy_dir", + "overwrite_output_dir": True, + "fp16": True, +} + + +@pytest.mark.runs_on(["cpu", "mps"]) +def test_load_train_only(): + dataset_module = load_dataset_module(**TRAIN_ARGS) + assert dataset_module.get("train_dataset") is not None + assert dataset_module.get("eval_dataset") is None + + +@pytest.mark.runs_on(["cpu", "mps"]) +def test_load_val_size(): + dataset_module = load_dataset_module(val_size=0.1, **TRAIN_ARGS) + assert dataset_module.get("train_dataset") is not None + assert dataset_module.get("eval_dataset") is not None + + +@pytest.mark.runs_on(["cpu", "mps"]) +def test_load_eval_data(): + dataset_module = load_dataset_module(eval_dataset=TINY_DATA, **TRAIN_ARGS) + assert dataset_module.get("train_dataset") is not None + assert dataset_module.get("eval_dataset") is not None diff --git a/LlamaFactory/tests/data/test_mm_plugin.py b/LlamaFactory/tests/data/test_mm_plugin.py new file mode 100644 index 0000000000000000000000000000000000000000..3187004aa5dcda9f2327370d4af47381a4ff5af2 --- /dev/null +++ b/LlamaFactory/tests/data/test_mm_plugin.py @@ -0,0 +1,433 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import TYPE_CHECKING, Any + +import numpy as np +import pytest +import torch +from PIL import Image + +from llamafactory.data.mm_plugin import get_mm_plugin +from llamafactory.extras.packages import is_transformers_version_greater_than +from llamafactory.hparams import get_infer_args +from llamafactory.model import load_tokenizer + + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer, ProcessorMixin + from transformers.image_processing_utils import BaseImageProcessor + + from llamafactory.data.mm_plugin import BasePlugin + from llamafactory.model.loader import TokenizerModule + + +HF_TOKEN = os.getenv("HF_TOKEN") + +TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3") +TINY_LLAMA4 = os.getenv("TINY_LLAMA4", "llamafactory/tiny-random-Llama-4") + +MM_MESSAGES = [ + {"role": "user", "content": "What is in this image?"}, + {"role": "assistant", "content": "A cat."}, +] + +OMNI_MESSAGES = [ + {"role": "user", "content": "What is in this image?"}, + {"role": "assistant", "content": "A cat."}, + {"role": "user", "content": "