diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..20ea40516e24e1bacb8e3434e3a7ca441764ee9b 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +model.safetensors.index.json filter=lfs diff=lfs merge=lfs -text +figures/demo_video.mp4 filter=lfs diff=lfs merge=lfs -text diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..7c219d7247ad815f9c73a93684402da0549e9724 --- /dev/null +++ b/LICENSE @@ -0,0 +1,27 @@ +Modified MIT License + +Copyright (c) 2026 Moonshot AI + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the “Software”), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +Our only modification part is that, if the Software (or any derivative works +thereof) is used for any of your commercial products or services that have +more than 100 million monthly active users, or more than 20 million US dollars +(or equivalent in other currencies) in monthly revenue, you shall prominently +display "Kimi K2.5" on the user interface of such product or service. diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..3f513061460f718271a97ce127864f4bdecf7449 --- /dev/null +++ b/README.md @@ -0,0 +1,737 @@ +--- +tags: +- unsloth +base_model: +- moonshotai/Kimi-K2.5 +license: other +license_name: modified-mit +library_name: transformers +pipeline_tag: image-text-to-text +--- +> [!NOTE] +> Includes Unsloth **chat template fixes**!
For `llama.cpp`, use `--jinja` +> + +
+

+ Unsloth Dynamic 2.0 achieves superior accuracy & outperforms other leading quants. +

+
+ + + + + + + + + +
+
+ +
+ + Kimi K2.5 + +
+
+
+ Chat + Homepage +
+ +
+ Hugging Face + Twitter Follow + Discord +
+
+ License +
+

+📰  Tech Blog +

+ +## 1. Model Introduction + +Kimi K2.5 is an open-source, native multimodal agentic model built through continual pretraining on approximately 15 trillion mixed visual and text tokens atop Kimi-K2-Base. It seamlessly integrates vision and language understanding with advanced agentic capabilities, instant and thinking modes, as well as conversational and agentic paradigms. + +### Key Features +- **Native Multimodality**: Pre-trained on vision–language tokens, K2.5 excels in visual knowledge, cross-modal reasoning, and agentic tool use grounded in visual inputs. +- **Coding with Vision**: K2.5 generates code from visual specifications (UI designs, video workflows) and autonomously orchestrates tools for visual data processing. +- **Agent Swarm**: K2.5 transitions from single-agent scaling to a self-directed, coordinated swarm-like execution scheme. It decomposes complex tasks into parallel sub-tasks executed by dynamically instantiated, domain-specific agents. + +## 2. Model Summary + +
+ + +| | | +|:---:|:---:| +| **Architecture** | Mixture-of-Experts (MoE) | +| **Total Parameters** | 1T | +| **Activated Parameters** | 32B | +| **Number of Layers** (Dense layer included) | 61 | +| **Number of Dense Layers** | 1 | +| **Attention Hidden Dimension** | 7168 | +| **MoE Hidden Dimension** (per Expert) | 2048 | +| **Number of Attention Heads** | 64 | +| **Number of Experts** | 384 | +| **Selected Experts per Token** | 8 | +| **Number of Shared Experts** | 1 | +| **Vocabulary Size** | 160K | +| **Context Length** | 256K | +| **Attention Mechanism** | MLA | +| **Activation Function** | SwiGLU | +| **Vision Encoder** | MoonViT | +| **Parameters of Vision Encoder** | 400M | +
+ +## 3. Evaluation Results + + + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
BenchmarkKimi K2.5
(Thinking)
GPT-5.2
(xhigh)
Claude 4.5 Opus
(Extended Thinking)
Gemini 3 Pro
(High Thinking Level)
DeepSeek V3.2
(Thinking)
Qwen3-VL-
235B-A22B-
Thinking
Reasoning & Knowledge
HLE-Full30.134.530.837.525.1-
HLE-Full
(w/ tools)
50.245.543.245.840.8-
AIME 202596.110092.895.093.1-
HMMT 2025 (Feb)95.499.492.9*97.3*92.5-
IMO-AnswerBench81.886.378.5*83.1*78.3-
GPQA-Diamond87.692.487.091.982.4-
MMLU-Pro87.186.7*89.3*90.185.0-
Image & Video
MMMU-Pro78.579.5*74.081.0-69.3
CharXiv (RQ)77.582.167.2*81.4-66.1
MathVision84.283.077.1*86.1*-74.6
MathVista (mini)90.182.8*80.2*89.8*-85.8
ZeroBench99*3*8*-4*
ZeroBench
(w/ tools)
117*9*12*-3*
OCRBench92.380.7*86.5*90.3*-87.5
OmniDocBench 1.588.885.787.7*88.5-82.0*
InfoVQA (val)92.684*76.9*57.2*-89.5
SimpleVQA71.255.8*69.7*69.7*-56.8*
WorldVQA46.328.036.847.4-23.5
VideoMMMU86.685.984.4*87.6-80.0
MMVU80.480.8*77.377.5-71.1
MotionBench70.464.860.370.3--
VideoMME87.486.0*-88.4*-79.0
LongVideoBench79.876.5*67.2*77.7*-65.6*
LVBench75.9--73.5*-63.6
Coding
SWE-Bench Verified76.880.080.976.273.1-
SWE-Bench Pro50.755.655.4*---
SWE-Bench Multilingual73.072.077.565.070.2-
Terminal Bench 2.050.854.059.354.246.4-
PaperBench63.563.7*72.9*-47.1-
CyberGym41.3-50.639.9*17.3*-
SciCode48.752.149.556.138.9-
OJBench (cpp)57.4-54.6*68.5*54.7*-
LiveCodeBench (v6)85.0-82.2*87.4*83.3-
Long Context
Longbench v261.054.5*64.4*68.2*59.8*-
AA-LCR70.072.3*71.3*65.3*64.3*-
Agentic Search
BrowseComp60.665.837.037.851.4-
BrowseComp
(w/ctx manage)
74.957.859.267.6-
BrowseComp
(Agent Swarm)
78.4-----
WideSearch
(item-f1)
72.7-76.2*57.032.5*-
WideSearch
(item-f1 Agent Swarm)
79.0-----
DeepSearchQA77.171.3*76.1*63.2*60.9*-
FinSearchCompT2&T367.8-66.2*49.959.1*-
Seal-057.445.047.7*45.5*49.5*-
+
+ +
+Footnotes + +1. General Testing Details + - We report results for Kimi K2.5 and DeepSeek-V3.2 with thinking mode enabled, Claude Opus 4.5 with extended thinking mode, GPT-5.2 with xhigh reasoning effort, and Gemini 3 Pro with a high thinking level. For vision benchmarks, we additionally report results for Qwen3-VL-235B-A22B-Thinking. + - Unless otherwise specified, all Kimi K2.5 experiments were conducted with temperature = 1.0, top-p = 0.95, and a context length of 256k tokens. + - Benchmarks without publicly available scores were re-evaluated under the same conditions used for Kimi K2.5 and are marked with an asterisk (*). + - We could not evaluate GPT-5.2 xhigh on all benchmarks due to service stability issues. For benchmarks that were not tested, we mark them as "-". +2. Text and Reasoning + - HLE, AIME 2025, HMMT 2025 (Feb), and GPQA-Diamond were evaluated with a maximum completion budget of 96k tokens. + - Results for AIME and HMMT are averaged over 32 runs (avg@32); GPQA-Diamond over 8 runs (avg@8). + - For HLE, we report scores on the full set (text & image). Kimi K2.5 scores 31.5 (text) and 21.3 (image) without tools, and 51.8 (text) and 39.8 (image) with tools. The DeepSeek-V3.2 score corresponds to its text-only subset (marked with †) . Hugging Face access was blocked to prevent potential data leakage. HLE with tools uses simple context management: once the context exceeds a threshold, only the latest round of tool messages is retained. +3. Tool-Augmented / Agentic Search + - Kimi K2.5 was equipped with search, code-interpreter, and web-browsing tools for HLE with tools and all agentic search benchmarks. + - Except for BrowseComp (where K2.5 and DeepSeek-V3.2 used the discard-all strategy), no context management was applied, and tasks exceeding the supported context length were directly counted as failed. + - The test system prompts emphasize deep and proactive tool use, instructing models to reason carefully, leverage tools, and verify uncertain information. Full prompts will be provided in the technical report. + - Results for Seal-0 and WideSearch are averaged over four runs (avg@4). +4. Vision Benchmarks + - Max-tokens = 64k, averaged over three runs (avg@3). + - ZeroBench (w/ tools) uses max-tokens-per-step = 24k and max-steps = 30 for multi-step reasoning. + - MMMU-Pro follows the official protocol, preserving input order and prepending images. + - GPT-5.2-xhigh had ~10% failure rate (no output despite 3 retries), treated as incorrect; reported scores likely underestimate true performance. + - WorldVQA, a benchmark designed to evaluate atomic vision-centric world knowledge. Access WorldVQA at https://github.com/MoonshotAI/WorldVQA. + - OmniDocBench Score is computed as (1 − normalized Levenshtein distance) × 100, where a higher score denotes superior accuracy. +5. Coding Tasks + - Terminal-Bench 2.0 scores were obtained with the default agent framework (Terminus-2) and the provided JSON parser. In our implementation, we evaluated Terminal-Bench 2.0 under non-thinking mode. This choice was made because our current context management strategy for the thinking mode is incompatible with Terminus-2. + - For the SWE-Bench series of evaluations (including verified, multilingual, and pro), we used an internally developed evaluation framework. This framework includes a minimal set of tools—bash tool, createfile tool, insert tool, view tool, strreplace tool, and submit tool—along with tailored system prompts designed for the tasks. The highest scores were achieved under non-thinking mode. + - The score of Claude Opus 4.5 on CyberGym is reported under the non-thinking setting. + - All reported scores of coding tasks are averaged over 5 independent runs. +6. Long-Context Benchmarks + - AA-LCR: scores averaged over three runs (avg@3). + - LongBench-V2: identical prompts and input contexts standardized to ~128k tokens. +7. Agent Swarm + - BrowseComp (Swarm Mode): main agent max 15 steps; sub-agents max 100 steps. + - WideSearch (Swarm Mode): main and sub-agents max 100 steps. + +
+ +## 4. Native INT4 Quantization +Kimi-K2.5 adopts the same native int4 quantization method as [Kimi-K2-Thinking](https://huggingface.co/moonshotai/Kimi-K2-Thinking#4-native-int4-quantization). + +## 5. Deployment +> [!Note] +> You can access Kimi-K2.5's API on https://platform.moonshot.ai , we provide OpenAI/Anthropic-compatible API for you. To verify the deployment is correct, we also provide the [Kimi Vendor Verifier](https://kimi.com/blog/kimi-vendor-verifier.html). +Currently, Kimi-K2.5 is recommended to run on the following inference engines: +* vLLM +* SGLang +* KTransformers + +Deployment examples can be found in the [Model Deployment Guide](docs/deploy_guidance.md). + +--- +## 6. Model Usage + +The usage demos below demonstrate how to call our official API. + +For third-party API deployed with vLLM or SGLang, please note that : +> [!Note] +> - Chat with video content is an experimental feature and is only supported in our official API for now +> +> - The recommended `temperature` will be `1.0` for Thinking mode and `0.6` for Instant mode. +> +> - The recommended `top_p` is `0.95` +> +> - To use instant mode, you need to pass `{'chat_template_kwargs': {"thinking": False}}` in `extra_body`. + +### Chat Completion + +This is a simple chat completion script which shows how to call K2.5 API in Thinking and Instant modes. + +```python +import openai +import base64 +import requests +def simple_chat(client: openai.OpenAI, model_name: str): + messages = [ + {'role': 'system', 'content': 'You are Kimi, an AI assistant created by Moonshot AI.'}, + { + 'role': 'user', + 'content': [ + {'type': 'text', 'text': 'which one is bigger, 9.11 or 9.9? think carefully.'} + ], + }, + ] + response = client.chat.completions.create( + model=model_name, messages=messages, stream=False, max_tokens=4096 + ) + print('===== Below is reasoning_content in Thinking Mode ======') + print(f'reasoning content: {response.choices[0].message.reasoning_content}') + print('===== Below is response in Thinking Mode ======') + print(f'response: {response.choices[0].message.content}') + + # To use instant mode, pass {"thinking" = {"type":"disabled"}} + response = client.chat.completions.create( + model=model_name, + messages=messages, + stream=False, + max_tokens=4096, + extra_body={'thinking': {'type': 'disabled'}}, # this is for official API + # extra_body= {'chat_template_kwargs': {"thinking": False}} # this is for vLLM/SGLang + ) + print('===== Below is response in Instant Mode ======') + print(f'response: {response.choices[0].message.content}') +``` + + +### Chat Completion with visual content + +K2.5 supports Image and Video input. + +The following example demonstrates how to call K2.5 API with image input: + +```python +import openai +import base64 +import requests + +def chat_with_image(client: openai.OpenAI, model_name: str): + url = 'https://huggingface.co/moonshotai/Kimi-K2.5/resolve/main/figures/kimi-logo.png' + image_base64 = base64.b64encode(requests.get(url).content).decode() + messages = [ + { + 'role': 'user', + 'content': [ + {'type': 'text', 'text': 'Describe this image in detail.'}, + { + 'type': 'image_url', + 'image_url': {'url': f'data:image/png;base64, {image_base64}'}, + }, + ], + } + ] + + response = client.chat.completions.create( + model=model_name, messages=messages, stream=False, max_tokens=8192 + ) + print('===== Below is reasoning_content in Thinking Mode ======') + print(f'reasoning content: {response.choices[0].message.reasoning_content}') + print('===== Below is response in Thinking Mode ======') + print(f'response: {response.choices[0].message.content}') + + # Also support instant mode if pass {"thinking" = {"type":"disabled"}} + response = client.chat.completions.create( + model=model_name, + messages=messages, + stream=False, + max_tokens=4096, + extra_body={'thinking': {'type': 'disabled'}}, # this is for official API + # extra_body= {'chat_template_kwargs': {"thinking": False}} # this is for vLLM/SGLang + ) + print('===== Below is response in Instant Mode ======') + print(f'response: {response.choices[0].message.content}') + + return response.choices[0].message.content +``` + +The following example demonstrates how to call K2.5 API with video input: + +```python +import openai +import base64 +import requests + +def chat_with_video(client: openai.OpenAI, model_name:str): + url = 'https://huggingface.co/moonshotai/Kimi-K2.5/resolve/main/figures/demo_video.mp4' + video_base64 = base64.b64encode(requests.get(url).content).decode() + messages = [ + { + "role": "user", + "content": [ + {"type": "text","text": "Describe the video in detail."}, + { + "type": "video_url", + "video_url": {"url": f"data:video/mp4;base64,{video_base64}"}, + }, + ], + } + ] + + response = client.chat.completions.create(model=model_name, messages=messages) + print('===== Below is reasoning_content in Thinking Mode ======') + print(f'reasoning content: {response.choices[0].message.reasoning_content}') + print('===== Below is response in Thinking Mode ======') + print(f'response: {response.choices[0].message.content}') + + # Also support instant mode if pass {"thinking" = {"type":"disabled"}} + response = client.chat.completions.create( + model=model_name, + messages=messages, + stream=False, + max_tokens=4096, + extra_body={'thinking': {'type': 'disabled'}}, # this is for official API + # extra_body= {'chat_template_kwargs': {"thinking": False}} # this is for vLLM/SGLang + ) + print('===== Below is response in Instant Mode ======') + print(f'response: {response.choices[0].message.content}') + return response.choices[0].message.content +``` + +### Interleaved Thinking and Multi-Step Tool Call + +K2.5 shares the same design of Interleaved Thinking and Multi-Step Tool Call as K2 Thinking. For usage example, please refer to the [K2 Thinking documentation](https://platform.moonshot.ai/docs/guide/use-kimi-k2-thinking-model#complete-example). + + +### Coding Agent Framework + +Kimi K2.5 works best with Kimi Code CLI as its agent framework — give it a try at https://www.kimi.com/code. + + +--- + +## 7. License + +Both the code repository and the model weights are released under the [Modified MIT License](LICENSE). + +--- + +## 8. Third Party Notices + +See [THIRD PARTY NOTICES](THIRD_PARTY_NOTICES.md) + +--- + +## 9. Contact Us + +If you have any questions, please reach out at [support@moonshot.cn](mailto:support@moonshot.cn). diff --git a/THIRD_PARTY_NOTICES.md b/THIRD_PARTY_NOTICES.md new file mode 100644 index 0000000000000000000000000000000000000000..c558728752e493c3764a7abdd1281e3d12bfed1d --- /dev/null +++ b/THIRD_PARTY_NOTICES.md @@ -0,0 +1,43 @@ +# THIRD_PARTY_NOTICES + +This file lists third-party software contained in Kimi-K2.5 along with their licenses, in compliance with the redistribution clauses of those licenses. + +--- + +## 1. DeepSeek-V3 + +Our model archietecture is DeepSeek-V3-like. Some of modeling codes are copied from the source repository. + +- **Source Repository** + https://huggingface.co/deepseek-ai/DeepSeek-V3 + +- **Files / Directories Used** + - configuration_deepseek.py + - modeling_deepseek.py + +- **License Type** + MIT License + +- **Copyright Notice** + Copyright (c) 2023 DeepSeek + +- **Full License Text** +``` +MIT License +Copyright (c) 2023 DeepSeek +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +``` diff --git a/chat_template.jinja b/chat_template.jinja new file mode 100644 index 0000000000000000000000000000000000000000..bc2867858b843a9de74c2649bf4b1dfd74039df3 --- /dev/null +++ b/chat_template.jinja @@ -0,0 +1,112 @@ +{%- macro render_content(msg) -%} + {%- set c = msg.get('content') -%} + {%- if c is string -%} + {{ c }} + {%- elif c is not none -%} + {% for content in c -%} + {% if content['type'] == 'image' or content['type'] == 'image_url' -%} + <|media_start|>image<|media_content|><|media_pad|><|media_end|> + {% elif content['type'] == 'video' or content['type']== 'video_url'-%} + <|kimi_k25_video_placeholder|> + {% else -%} + {{ content['text'] }} + {%- endif -%} + {%- endfor -%} + {%- endif -%} +{%- endmacro -%} + +{% macro set_roles(message) -%} + {%- set role_name = message.get('name') or message['role'] -%} + {%- if message['role'] == 'user' -%} + <|im_user|>{{role_name}}<|im_middle|> + {%- elif message['role'] == 'assistant' -%} + <|im_assistant|>{{role_name}}<|im_middle|> + {%- else -%} + <|im_system|>{{role_name}}<|im_middle|> + {%- endif -%} +{%- endmacro -%} + + +{%- macro render_toolcalls(message) -%} + <|tool_calls_section_begin|> + {%- for tool_call in message['tool_calls'] -%} + {%- set formatted_id = tool_call['id'] -%} + <|tool_call_begin|>{{ formatted_id }}<|tool_call_argument_begin|>{% if tool_call['function']['arguments'] is string %}{{ tool_call['function']['arguments'] }}{% else %}{{ tool_call['function']['arguments'] | tojson }}{% endif %}<|tool_call_end|> + {%- endfor -%} + <|tool_calls_section_end|> +{%- endmacro -%} + + +{# Find last non-tool-call assisitant message #} +{%- set ns = namespace(last_non_tool_call_assistant_msg=-1) -%} +{%- for idx in range(messages|length-1, -1, -1) -%} + {%- if messages[idx]['role'] == 'assistant' and not messages[idx].get('tool_calls') -%} + {%- set ns.last_non_tool_call_assistant_msg = idx -%} + {%- break -%} + {%- endif -%} +{%- endfor -%} + +{# split all messages into history & suffix, reasoning_content in suffix should be reserved.#} +{%- set hist_msgs = messages[:ns.last_non_tool_call_assistant_msg+1] -%} +{%- set suffix_msgs = messages[ns.last_non_tool_call_assistant_msg+1:] -%} + +{%- if tools -%} + {%- if tools_ts_str -%} + <|im_system|>tool_declare<|im_middle|>{{ tools_ts_str }}<|im_end|> + {%- else -%} + <|im_system|>tool_declare<|im_middle|>{{ tools | tojson(separators=(',', ':')) }}<|im_end|> + {%- endif -%} +{%- endif -%} + +{%- if messages|length == 0 or messages[0]['role'] != 'system' -%} + <|im_system|>system<|im_middle|>You are Kimi, an AI assistant created by Moonshot AI.<|im_end|> +{%- endif -%} + +{%- for message in hist_msgs -%} + {{set_roles(message)}} + {%- if message['role'] == 'assistant' -%} + {{render_content(message)}} + {%- if message.get('tool_calls') -%} + {{render_toolcalls(message)}} + {%- endif -%} + {%- elif message['role'] == 'tool' -%} + {%- set tool_call_id = message.tool_call_id -%} + ## Return of {{ tool_call_id }} +{{render_content(message)}} + {%- elif message['content'] is not none -%} + {{render_content(message)}} + {%- endif -%} + <|im_end|> +{%- endfor -%} + +{%- for message in suffix_msgs -%} + {{set_roles(message)}} + {%- if message['role'] == 'assistant' -%} + {%- if thinking is defined and thinking is false -%} + {{render_content(message)}} + {%- else -%} + {%- set rc = message.get('reasoning_content', '') -%} + {{rc}}{{render_content(message)}} + {%- endif -%} + {%- if message.get('tool_calls') -%} + {{render_toolcalls(message)}} + {%- endif -%} + {%- elif message['role'] == 'tool' -%} + {%- set tool_call_id = message.tool_call_id -%} + ## Return of {{ tool_call_id }} +{{render_content(message)}} + {%- elif message['content'] is not none -%} + {{render_content(message)}} + {%- endif -%} + <|im_end|> +{%- endfor -%} + + +{%- if add_generation_prompt -%} + <|im_assistant|>assistant<|im_middle|> + {%- if thinking is defined and thinking is false -%} + + {%- else -%} + + {%- endif -%} +{%- endif -%} \ No newline at end of file diff --git a/config.json b/config.json new file mode 100644 index 0000000000000000000000000000000000000000..cdb3bfcc8d7f395854b7ef4ff225b0f86f7a9161 --- /dev/null +++ b/config.json @@ -0,0 +1,226 @@ +{ + "architectures": [ + "KimiK25ForConditionalGeneration" + ], + "auto_map": { + "AutoConfig": "configuration_kimi_k25.KimiK25Config", + "AutoModel": "modeling_kimi_k25.KimiK25ForConditionalGeneration", + "AutoModelForCausalLM": "modeling_kimi_k25.KimiK25ForConditionalGeneration" + }, + "bos_token_id": 163584, + "torch_dtype": "bfloat16", + "eos_token_id": 163585, + "ignore_index": -100, + "media_placeholder_token_id": 163605, + "model_type": "kimi_k25", + "pad_token_id": 163839, + "quantization_config": { + "config_groups": { + "group_0": { + "input_activations": null, + "output_activations": null, + "targets": [ + "Linear" + ], + "weights": { + "actorder": null, + "block_structure": null, + "dynamic": false, + "group_size": 32, + "num_bits": 4, + "observer": "minmax", + "observer_kwargs": {}, + "strategy": "group", + "symmetric": true, + "type": "int" + } + } + }, + "format": "pack-quantized", + "ignore": [ + "lm_head", + "re:.*self_attn.*", + "re:.*shared_experts.*", + "re:.*mlp\\.(gate|up|gate_up|down)_proj.*" + ], + "kv_cache_scheme": null, + "quant_method": "compressed-tensors", + "quantization_status": "compressed" + }, + "text_config": { + "_name_or_path": "", + "add_cross_attention": false, + "architectures": [ + "DeepseekV3ForCausalLM" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "auto_map": { + "AutoConfig": "configuration_deepseek.DeepseekV3Config", + "AutoModel": "modeling_deepseek.DeepseekV3Model", + "AutoModelForCausalLM": "modeling_deepseek.DeepseekV3ForCausalLM" + }, + "aux_loss_alpha": 0.001, + "bad_words_ids": null, + "begin_suppress_tokens": null, + "bos_token_id": 163584, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": null, + "decoder_start_token_id": null, + "diversity_penalty": 0.0, + "do_sample": false, + "torch_dtype": "bfloat16", + "early_stopping": false, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": 163585, + "ep_size": 1, + "exponential_decay_length_penalty": null, + "finetuning_task": null, + "first_k_dense_replace": 1, + "forced_bos_token_id": null, + "forced_eos_token_id": null, + "hidden_act": "silu", + "hidden_size": 7168, + "id2label": { + "0": "LABEL_0", + "1": "LABEL_1" + }, + "initializer_range": 0.02, + "intermediate_size": 18432, + "is_decoder": false, + "is_encoder_decoder": false, + "kv_lora_rank": 512, + "label2id": { + "LABEL_0": 0, + "LABEL_1": 1 + }, + "length_penalty": 1.0, + "max_length": 20, + "max_position_embeddings": 262144, + "min_length": 0, + "model_type": "deepseek_v3", + "moe_intermediate_size": 2048, + "moe_layer_freq": 1, + "n_group": 1, + "n_routed_experts": 384, + "n_shared_experts": 1, + "no_repeat_ngram_size": 0, + "norm_topk_prob": true, + "num_attention_heads": 64, + "num_beam_groups": 1, + "num_beams": 1, + "num_experts_per_tok": 8, + "num_hidden_layers": 61, + "num_key_value_heads": 64, + "num_nextn_predict_layers": 0, + "num_return_sequences": 1, + "output_attentions": false, + "output_hidden_states": false, + "output_scores": false, + "pad_token_id": 163839, + "prefix": null, + "pretraining_tp": 1, + "problem_type": null, + "pruned_heads": {}, + "q_lora_rank": 1536, + "qk_nope_head_dim": 128, + "qk_rope_head_dim": 64, + "quantization_config": { + "config_groups": { + "group_0": { + "input_activations": null, + "output_activations": null, + "targets": [ + "Linear" + ], + "weights": { + "actorder": null, + "block_structure": null, + "dynamic": false, + "group_size": 32, + "num_bits": 4, + "observer": "minmax", + "observer_kwargs": {}, + "strategy": "group", + "symmetric": true, + "type": "int" + } + } + }, + "format": "pack-quantized", + "ignore": [ + "lm_head", + "re:.*self_attn.*", + "re:.*shared_experts.*", + "re:.*mlp\\.(gate|up|gate_up|down)_proj.*" + ], + "kv_cache_scheme": null, + "quant_method": "compressed-tensors", + "quantization_status": "compressed" + }, + "remove_invalid_values": false, + "repetition_penalty": 1.0, + "return_dict": true, + "return_dict_in_generate": false, + "rms_norm_eps": 1e-05, + "rope_scaling": { + "beta_fast": 32.0, + "beta_slow": 1.0, + "factor": 64.0, + "mscale": 1.0, + "mscale_all_dim": 1.0, + "original_max_position_embeddings": 4096, + "type": "yarn" + }, + "rope_theta": 50000.0, + "routed_scaling_factor": 2.827, + "scoring_func": "sigmoid", + "sep_token_id": null, + "seq_aux": true, + "suppress_tokens": null, + "task_specific_params": null, + "temperature": 1.0, + "tf_legacy_loss": false, + "tie_encoder_decoder": false, + "tie_word_embeddings": false, + "tokenizer_class": null, + "top_k": 50, + "top_p": 1.0, + "topk_group": 1, + "topk_method": "noaux_tc", + "torchscript": false, + "typical_p": 1.0, + "use_bfloat16": false, + "use_cache": true, + "v_head_dim": 128, + "vocab_size": 163840 + }, + "tie_word_embeddings": false, + "transformers_version": "4.57.3", + "unsloth_fixed": true, + "use_unified_vision_chunk": true, + "video_placeholder": "<|kimi_k25_video_placeholder|>", + "vision_config": { + "init_pos_emb_height": 64, + "init_pos_emb_time": 4, + "init_pos_emb_width": 64, + "merge_kernel_size": [ + 2, + 2 + ], + "merge_type": "sd2_tpool", + "mm_hidden_size": 1152, + "mm_projector_type": "patchmerger", + "model_type": "", + "patch_size": 14, + "pos_emb_type": "divided_fixed", + "projector_hidden_act": "gelu", + "projector_ln_eps": 1e-05, + "text_hidden_size": 7168, + "video_attn_type": "spatial_temporal", + "vt_hidden_size": 1152, + "vt_intermediate_size": 4304, + "vt_num_attention_heads": 16, + "vt_num_hidden_layers": 27 + } +} diff --git a/configuration_deepseek.py b/configuration_deepseek.py new file mode 100644 index 0000000000000000000000000000000000000000..b3152dd7c3e53d223d561848dc967f487daf32ef --- /dev/null +++ b/configuration_deepseek.py @@ -0,0 +1,214 @@ +# Copy from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/configuration_deepseek.py + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {} + + +class DeepseekV3Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the DeepSeek-V3. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 129280): + Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`DeepseekV3Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + moe_intermediate_size (`int`, *optional*, defaults to 1407): + Dimension of the MoE representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_nextn_predict_layers (`int`, *optional*, defaults to 1): + Number of nextn predict layers in the DeepSeekV3 Model. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + n_shared_experts (`int`, *optional*, defaults to None): + Number of shared experts, None means dense model. + n_routed_experts (`int`, *optional*, defaults to None): + Number of routed experts, None means dense model. + routed_scaling_factor (`float`, *optional*, defaults to 1.0): + Scaling factor or routed experts. + topk_method (`str`, *optional*, defaults to `gready`): + Topk method used in routed gate. + n_group (`int`, *optional*, defaults to None): + Number of groups for routed experts. + topk_group (`int`, *optional*, defaults to None): + Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups). + num_experts_per_tok (`int`, *optional*, defaults to None): + Number of selected experts, None means dense model. + moe_layer_freq (`int`, *optional*, defaults to 1): + The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers. + first_k_dense_replace (`int`, *optional*, defaults to 0): + Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). + \--k dense layers--/ + norm_topk_prob (`bool`, *optional*, defaults to False): + Whether to normalize the weights of the routed experts. + scoring_func (`str`, *optional*, defaults to 'softmax'): + Method of computing expert weights. + aux_loss_alpha (`float`, *optional*, defaults to 0.001): + Auxiliary loss weight coefficient. + seq_aux = (`bool`, *optional*, defaults to True): + Whether to compute the auxiliary loss for each individual sample. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is + necessary to ensure exact reproducibility of the pretraining results. Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + + ```python + >>> from transformers import DeepseekV3Model, DeepseekV3Config + + >>> # Initializing a Deepseek-V3 style configuration + >>> configuration = DeepseekV3Config() + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "deepseek_v3" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=129280, + hidden_size=7168, + intermediate_size=18432, + moe_intermediate_size=2048, + num_hidden_layers=61, + num_nextn_predict_layers=1, + num_attention_heads=128, + num_key_value_heads=128, + n_shared_experts=1, + n_routed_experts=256, + ep_size=1, + routed_scaling_factor=2.5, + kv_lora_rank=512, + q_lora_rank=1536, + qk_rope_head_dim=64, + v_head_dim=128, + qk_nope_head_dim=128, + topk_method='noaux_tc', + n_group=8, + topk_group=4, + num_experts_per_tok=8, + moe_layer_freq=1, + first_k_dense_replace=3, + norm_topk_prob=True, + scoring_func='sigmoid', + aux_loss_alpha=0.001, + seq_aux=True, + hidden_act="silu", + max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=0, + eos_token_id=1, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_nextn_predict_layers = num_nextn_predict_layers + self.num_attention_heads = num_attention_heads + self.n_shared_experts = n_shared_experts + self.n_routed_experts = n_routed_experts + self.ep_size = ep_size + self.routed_scaling_factor = routed_scaling_factor + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.topk_method = topk_method + self.n_group = n_group + self.topk_group = topk_group + self.num_experts_per_tok = num_experts_per_tok + self.moe_layer_freq = moe_layer_freq + self.first_k_dense_replace = first_k_dense_replace + self.norm_topk_prob = norm_topk_prob + self.scoring_func = scoring_func + self.aux_loss_alpha = aux_loss_alpha + self.seq_aux = seq_aux + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/configuration_kimi_k25.py b/configuration_kimi_k25.py new file mode 100644 index 0000000000000000000000000000000000000000..5858b3290a32509480affd58abc01482d5976550 --- /dev/null +++ b/configuration_kimi_k25.py @@ -0,0 +1,123 @@ +from transformers.configuration_utils import PretrainedConfig + +try: + from configuration_deepseek import DeepseekV3Config +except ImportError: + from .configuration_deepseek import DeepseekV3Config + + +class KimiK25VisionConfig(PretrainedConfig): + + def __init__( + self, + patch_size: int = 14, + init_pos_emb_height: int = 64, + init_pos_emb_width: int = 64, + init_pos_emb_time: int = 4, + pos_emb_type: str = 'divided_fixed', + vt_num_attention_heads: int = 16, + vt_num_hidden_layers: int = 27, + vt_hidden_size: int = 1152, + vt_intermediate_size: int = 4304, + merge_kernel_size: tuple = (2, 2), + video_attn_type: str = 'spatial_temporal', + merge_type: str = 'sd2_tpool', + _attn_implementation: str = 'flash_attention_2', + # MM Projector parameters + mm_projector_type: str = 'patchmerger', + mm_hidden_size: int | None = None, + projector_hidden_act: str = "gelu", + projector_ln_eps: float = 1e-5, + # Other parameters + ignore_index: int = -100, + media_placeholder_token_id: int = 163605, + pad_token_id: int = 0, + use_unified_vision_chunk: bool = True, + video_placeholder="<|kimi_k25_video_placeholder|>", + text_hidden_size=7168, + **vision_config_kwargs): + + self.patch_size = patch_size + self.init_pos_emb_height = init_pos_emb_height + self.init_pos_emb_width = init_pos_emb_width + self.init_pos_emb_time = init_pos_emb_time + self.pos_emb_type = pos_emb_type + self.vt_num_attention_heads = vt_num_attention_heads + self.vt_num_hidden_layers = vt_num_hidden_layers + self.vt_hidden_size = vt_hidden_size + self.vt_intermediate_size = vt_intermediate_size + self.merge_kernel_size = merge_kernel_size + self.video_attn_type = video_attn_type + self.merge_type = merge_type + self._attn_implementation = _attn_implementation + + # MM Projector config + self.mm_projector_type = mm_projector_type + self.mm_hidden_size = mm_hidden_size if mm_hidden_size is not None else vt_hidden_size + self.projector_hidden_act = projector_hidden_act + self.projector_ln_eps = projector_ln_eps + self.text_hidden_size = text_hidden_size + + +class KimiK25Config(PretrainedConfig): + """Kimi-K2.5 model configuration. + + Args: + text_config (dict | DeepseekV3Config): Configuration for the text model. + + Vision Tower Parameters (from MoonViT3dConfig): + patch_size (int): Patch size for vision tower. + init_pos_emb_height (int): Initial position embedding height. + init_pos_emb_width (int): Initial position embedding width. + init_pos_emb_time (int): Initial position embedding time dimension. + pos_emb_type (str): Type of position embedding. + vt_num_attention_heads (int): Number of attention heads in vision tower. + vt_num_hidden_layers (int): Number of hidden layers in vision tower. + vt_hidden_size (int): Hidden size of vision tower. + vt_intermediate_size (int): Intermediate size in vision tower FFN. + merge_kernel_size (tuple): Kernel size for patch merging. + video_attn_type (str): Type of video attention. + merge_type (str): Type of merge operation. + _attn_implementation (str): Attention implementation type. + + MM Projector Parameters (from MultiModalProjectorConfig): + mm_projector_type (str): Type of multimodal projector. + mm_hidden_size (int): Hidden size from vision tower (should match vt_hidden_size). + projector_hidden_act (str): Activation function for projector. + projector_ln_eps (float): Layer norm epsilon for projector. + + Other Parameters: + ignore_index (int): The ignore index for the loss function. + media_placeholder_token_id (int): The token ID to use for media placeholders. + pad_token_id (int): The token ID to use for padding. + """ + + model_type = "kimi_k25" + + def __init__( + self, + text_config: dict | DeepseekV3Config = None, + vision_config: dict | KimiK25VisionConfig = None, + # Other parameters + ignore_index: int = -100, + media_placeholder_token_id: int = 163605, + pad_token_id: int = 0, + use_unified_vision_chunk: bool = True, + video_placeholder="<|kimi_k25_video_placeholder|>", + **kwargs, + ): + if isinstance(text_config, dict): + text_config = DeepseekV3Config(**text_config) + if isinstance(vision_config, dict): + vision_config = KimiK25VisionConfig(**vision_config) + self.text_config = text_config + self.vision_config = vision_config + # Other config + self.ignore_index = ignore_index + self.media_placeholder_token_id = media_placeholder_token_id + self.use_unified_vision_chunk = use_unified_vision_chunk + self.video_placeholder = video_placeholder + if getattr(self.text_config, "quantization_config", None) is not None: + self.quantization_config = self.text_config.quantization_config + + super().__init__(pad_token_id=pad_token_id, **kwargs) diff --git a/docs/deploy_guidance.md b/docs/deploy_guidance.md new file mode 100644 index 0000000000000000000000000000000000000000..696b7586364b1598d608adfeb7070a5a973007b7 --- /dev/null +++ b/docs/deploy_guidance.md @@ -0,0 +1,82 @@ +# Kimi-K2.5 Deployment Guide + +> [!Note] +> This guide only provides some examples of deployment commands for Kimi-K2.5, which may not be the optimal configuration. Since inference engines are still being updated frequenty, please continue to follow the guidance from their homepage if you want to achieve better inference performance. + +> kimi_k2 reasoning parser and other related features have been merged into vLLM/sglang and will be available in the next release. For now, please use the nightly build Docker image. +## vLLM Deployment + +This model is available in nightly vLLM wheel: +``` +uv pip install -U vllm \ + --torch-backend=auto \ + --extra-index-url https://wheels.vllm.ai/nightly +``` + +Here is the example to serve this model on a H200 single node with TP8 via vLLM: +```bash +vllm serve $MODEL_PATH --tp 8 --trust-remote-code --tool-call-parser kimi_k2 --reasoning-parser kimi_k2 +``` +**Key notes** +- `--tool-call-parser kimi_k2`: Required for enabling tool calling +- `--reasoning-parser kimi_k2`: Kimi-K2.5 enables thinking mode by default. Make sure to pass this for correct reasoning processing. + +## SGLang Deployment + +This model is available in SGLang latest main: + +``` +pip install "sglang @ git+https://github.com/sgl-project/sglang.git#subdirectory=python" +pip install nvidia-cudnn-cu12==9.16.0.29 +``` + +Similarly, here is the example for it to run with TP8 on H200 in a single node via SGLang: +``` bash +sglang serve --model-path $MODEL_PATH --tp 8 --trust-remote-code --tool-call-parser kimi_k2 --reasoning-parser kimi_k2 +``` +**Key parameter notes:** +- `--tool-call-parser kimi_k2`: Required when enabling tool usage. +- `--reasoning-parser kimi_k2`: Required for correctly processing reasoning content. + +## KTransformers Deployment +### KTransformers+SGLang Inference Deployment +Launch with KTransformers + SGLang for CPU+GPU heterogeneous inference: + +``` +python -m sglang.launch_server \ + --model path/to/Kimi-K2.5/ \ + --kt-amx-weight-path path/to/Kimi-K2.5/ \ + --kt-cpuinfer 64 \ + --kt-threadpool-count 2 \ + --kt-num-gpu-experts 180 \ + --kt-amx-method AMXINT4 \ + --trust-remote-code \ + --mem-fraction-static 0.98 \ + --chunked-prefill-size 16384 \ + --max-running-requests 48 \ + --max-total-tokens 50000 \ + --tensor-parallel-size 8 \ + --enable-p2p-check \ + --disable-shared-experts-fusion +``` + +Achieves 640.12 tokens/s Prefill and 24.51 tokens/s Decode (48-way concurrency) on 8× NVIDIA L20 + 2× Intel 6454S. + +More details: https://github.com/kvcache-ai/ktransformers/blob/main/doc/en/Kimi-K2.5.md . + +### KTransformers+LLaMA-Factory Fine-tuning Deployment + +You can use below command to run LoRA SFT with KT+llamafactory. + +``` +# For LoRA SFT +USE_KT=1 llamafactory-cli train examples/train_lora/kimik2_lora_sft_kt.yaml +# For Chat with model after LoRA SFT +llamafactory-cli chat examples/inference/kimik2_lora_sft_kt.yaml +# For API with model after LoRA SFT +llamafactory-cli api examples/inference/kimik2_lora_sft_kt.yaml +``` + +This achieves end-to-end LoRA SFT Throughput: 44.55 token/s on 2× NVIDIA 4090 + Intel 8488C with 1.97T RAM and 200G swap memory. + +More details refer to https://github.com/kvcache-ai/ktransformers/blob/main/doc/en/SFT_Installation_Guide_KimiK2.5.md . diff --git a/figures/demo_video.mp4 b/figures/demo_video.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..d4d34d31bd1a855188d341793a3144d7f97dd7d4 --- /dev/null +++ b/figures/demo_video.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:09b4d925aa0a7c712feef50765355f0625d8f6d46ea302fd98db9609e9070047 +size 270100 diff --git a/figures/kimi-logo.png b/figures/kimi-logo.png new file mode 100644 index 0000000000000000000000000000000000000000..870b8be6e07cc2c46f7173e800fbaff8af0af5d1 Binary files /dev/null and b/figures/kimi-logo.png differ diff --git a/generation_config.json b/generation_config.json new file mode 100644 index 0000000000000000000000000000000000000000..4bb3f8ec4a1d604598b7ffb4621b955e995bda92 --- /dev/null +++ b/generation_config.json @@ -0,0 +1,4 @@ +{ + "max_length": 262144, + "eos_token_id": 163586 +} \ No newline at end of file diff --git a/kimi_k25_processor.py b/kimi_k25_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..d526032f91036de5f3d226b866acf449553b986d --- /dev/null +++ b/kimi_k25_processor.py @@ -0,0 +1,165 @@ +from transformers.feature_extraction_utils import BatchFeature +from transformers.processing_utils import ProcessorMixin +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class KimiK25Processor(ProcessorMixin): + r""" + Constructs a KimiK25 processor which wraps a KimiK25 image processor and a tokenizer into a single processor. + + [`KimiK25Processor`] offers all the functionalities of [`KimiK25ImageProcessor`] and [`TikTokenTokenizer`]. See the + [`~KimiK25Processor.__call__`] and [`~KimiK25Processor.decode`] for more information. + + Args: + image_processor ([`KimiK25ImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`TikTokenTokenizer`], *optional*): + The tokenizer is a required input. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. + """ + + attributes = ["image_processor", "tokenizer"] + valid_kwargs = ["chat_template"] + image_processor_class = "AutoImageProcessor" + tokenizer_class = "AutoTokenizer" + + def __init__( + self, + image_processor=None, + tokenizer=None, + chat_template=None, + **kwargs, + ): + super().__init__(image_processor, + tokenizer, + chat_template=chat_template) + self.media_processor = image_processor + # A special temporal placeholder to be replaced by actual video placeholders + self.video_placeholder = "<|kimi_k25_video_placeholder|>" + + def update_raw_text(self, text: str, video_prompts: list[str]) -> str: + # replace video prompt in text with video chunk prompts + video_count = text.count(self.video_placeholder) + if video_count == 0: + return text + assert video_count == len(video_prompts) + text_parts = text.split(self.video_placeholder) + assert len(text_parts) == len(video_prompts) + 1 + text = "".join([ + text_parts[i] + video_prompts[i] for i in range(len(video_prompts)) + ]) + text += text_parts[-1] + return text + + def preprocess_medias(self, medias: list[dict]) -> list[dict]: + updated_medias = [] + video_prompts = [] + for media in medias: + if media['type'] == 'image': + updated_medias.append(media) + elif media['type'] == 'video': + video_chunks = self.media_processor.split_video_chunks( + media['video']) + updated_medias.extend(video_chunks) + video_prompts.append("".join( + [vc['prompt'] for vc in video_chunks])) + else: + raise ValueError(f"unsupported media type: {media['type']}") + return updated_medias, video_prompts + + def __call__(self, + messages: list[dict] = None, + medias: list[dict] = None, + text: str = None, + return_tensors: str = "pt", + **kwargs) -> BatchFeature: + """ + Process multimodal inputs for Kimi-K2.5 model. + + This processor accepts ordered messages and extracts both media and text in a single pass. + text will be automatically updated if video input detected in messages + + Args: + messages: List of message dicts with 'role' and 'content' fields. + If provided, medias and text will be extracted automatically. + medias: Pre-extracted list of media dicts. If None, extracted from messages. + text: Pre-formatted text string. If None, generated via apply_chat_template. + return_tensors: Format of returned tensors ('pt', 'np', 'tf'). Default: 'pt'. + **kwargs: Additional arguments passed to tokenizer.apply_chat_template. + + Returns: + BatchFeature with fields: input_ids, attention_mask, pixel_values, grid_thws. + """ + if messages is None and (medias is None or text is None): + raise ValueError( + "Provide either 'messages' or both 'medias' and 'text'") + + if medias is not None and text is not None: + updated_medias, video_prompts = self.preprocess_medias(medias) + preprocessed = self.media_processor.preprocess( + updated_medias, return_tensors=return_tensors) + text = self.update_raw_text(text, video_prompts) + text_inputs = self.tokenizer(text, return_tensors=return_tensors) + return BatchFeature(data={**text_inputs, **preprocessed.data}) + + if medias is None: + medias = self._extract_medias_from_messages(messages) + updated_medias, video_prompts = self.preprocess_medias(medias) + preprocessed = self.media_processor.preprocess( + updated_medias, return_tensors=return_tensors) + + # Generate text if not provided + if text is None: + text = self.tokenizer.apply_chat_template(messages, **kwargs) + + text = self.update_raw_text(text, video_prompts) + + text_inputs = self.tokenizer(text, return_tensors=return_tensors) + return BatchFeature(data={**text_inputs, **preprocessed.data}) + + @staticmethod + def _extract_medias_from_messages(messages: list[dict]) -> list[dict]: + """ + Extract media items from messages in a single pass. + + This is an optimized version that processes messages only once. + Kept as internal method since external callers should use __call__. + """ + medias = [] + for msg in messages: + if msg['role'] != 'user' or not msg.get('content'): + continue + + for content_part in msg['content']: + if not isinstance(content_part, dict): + continue + + content_type = content_part.get('type') + if content_type in ['video_url', 'video']: + medias.append({ + 'type': 'video', + 'video': content_part['video_url']['url'], + 'first_frame_timestamp': 0.0 + }) + elif content_type in ['image_url', 'image']: + medias.append({ + 'type': 'image', + 'image': content_part['image_url'], + }) + return medias + + def apply_chat_template(self, messages, **kwargs): + return self.tokenizer.apply_chat_template(messages, **kwargs) + + def batch_decode(self, *args, **kwargs): + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + return ['input_ids', 'attention_mask', 'pixel_values', 'grid_thws'] diff --git a/kimi_k25_vision_processing.py b/kimi_k25_vision_processing.py new file mode 100644 index 0000000000000000000000000000000000000000..fdf3ab2f100f7c28a1f1e7295297e54b515d0b53 --- /dev/null +++ b/kimi_k25_vision_processing.py @@ -0,0 +1,251 @@ +"""Image processor class for Kimi-K2.5. +""" + +import json +from typing import Any, Dict, Optional, Union + +import numpy as np +import torch +from PIL import Image +from transformers.image_processing_utils import (BaseImageProcessor, + BatchFeature) +from transformers.utils import TensorType + +from .media_utils import (MediaInput, VideoChunkInput, _to_tensor, + ensure_media_type, get_video_meta, image_to_np, + navit_patchify, navit_resize_image, + navit_resize_video, normalize, + real_sample_fps_and_max_num_frames, timestamp_as_str) + +try: + from mecord import VideoReader +except ImportError: + VideoReader = None + + +def resampling(video_bytes: bytes, + sample_indices: list[int], + key_indices=None, + frame_time_info=None, + num_threads=4) -> str: + video = VideoReader(video_bytes, + num_threads=num_threads, + frame_time_info=frame_time_info, + key_indices=key_indices) + # extract target frames + frames = video[sample_indices] + frames = [Image.fromarray(frame) for frame in frames] + return frames + + +class KimiK25VisionProcessor(BaseImageProcessor): + model_type = "kimi_k25" + + def __init__( + self, + media_proc_cfg: dict, + **kwargs, + ): + super().__init__(**kwargs) + self.media_proc_cfg = media_proc_cfg + self.num_frames_per_chunk = media_proc_cfg[ + 'temporal_merge_kernel_size'] + + def media_tokens_calculator(self, media: MediaInput): + media = ensure_media_type(media) + ret = self.get_resize_config(media) + return ret['num_tokens'] + + @classmethod + def make_chunk_prompt(cls, timestamp_text: str) -> str: + return f"{timestamp_text}<|media_begin|>video<|media_content|><|media_pad|><|media_end|>" + + def split_video_chunks(self, + video_url: str | bytes) -> list[list[Image.Image]]: + # video_url should be base64 str or bytes + video_spec = get_video_meta(video_url) + sample_fps = min(self.media_proc_cfg['sample_fps'], video_spec.fps) + sampled_nframes = max( + round(video_spec.num_frames * sample_fps / video_spec.fps), 1) + frame_inds = np.linspace(0, video_spec.num_frames - 1, + sampled_nframes).round().astype(int) + frame_inds = frame_inds.tolist() + sampled_frame_ids = [] + temporal_merge_kernel_size = self.media_proc_cfg[ + "temporal_merge_kernel_size"] + num_chunks = 0 + chunk_timestamp = [] + for i in range(0, len(frame_inds), temporal_merge_kernel_size): + sampled_frame_ids.extend(frame_inds[i:i + + temporal_merge_kernel_size]) + start_time = frame_inds[i] / float(video_spec.fps) + timestamp_text = timestamp_as_str( + start_time, self.media_proc_cfg["timestamp_mode"]) + chunk_timestamp.append(timestamp_text) + num_chunks += 1 + + sampled_frames = resampling(video_url, sampled_frame_ids) + chunks = [] + for chunk_id in range(num_chunks): + chunk = sampled_frames[chunk_id * + temporal_merge_kernel_size:(chunk_id + 1) * + temporal_merge_kernel_size] + chunks.append( + VideoChunkInput(type="video_chunk", + video_chunk=chunk, + prompt=self.make_chunk_prompt( + chunk_timestamp[chunk_id]))) + return chunks + + def get_resize_config(self, media_input: MediaInput) -> dict: + if media_input['type'] == 'image': + w, h = media_input['image'].size + ret = navit_resize_image( + w, h, self.media_proc_cfg['patch_size'], + self.media_proc_cfg['merge_kernel_size'], + self.media_proc_cfg['in_patch_limit'], + self.media_proc_cfg['patch_limit_on_one_side'], + self.media_proc_cfg['fixed_output_tokens']) + return ret + elif media_input['type'] == 'video_chunk': + frame = media_input['video_chunk'][0] + width, height = frame.size + num_frames = len(media_input["video_chunk"]) + fps = 1.0 + + sample_fps, max_num_frames_each_video = real_sample_fps_and_max_num_frames( + media_input["type"], + self.media_proc_cfg['sample_fps'], + self.media_proc_cfg['max_num_frames_each_video'], + ) + + in_patch_limit_each_frame = self.media_proc_cfg[ + 'in_patch_limit_each_frame'] + if in_patch_limit_each_frame is None: + in_patch_limit_each_frame = self.media_proc_cfg[ + 'in_patch_limit'] + + ret = navit_resize_video( + width, + height, + num_frames, + fps, + sample_fps, + self.media_proc_cfg['patch_size'], + self.media_proc_cfg['merge_kernel_size'], + in_patch_limit_each_frame, + self.media_proc_cfg['patch_limit_on_one_side'], + self.media_proc_cfg['in_patch_limit_video'], + max_num_frames_each_video, + self.media_proc_cfg['fixed_output_tokens'], + ) + return ret + else: + raise ValueError("Unsupported type: {}".format( + media_input['type'])) + + def resize_image(self, image: Image.Image, new_width: int, new_height: int, + pad_width: int, pad_height: int) -> np.ndarray: + image_np = image_to_np(image, (new_width, new_height), "resize") + image_np = np.pad( + image_np, + ((0, pad_height), (0, pad_width), (0, 0)), + mode="constant", + constant_values=0, + ) + return image_np + + def preprocess( + self, + medias: list[MediaInput], + return_tensors: Optional[Union[str, TensorType]] = None, + ) -> BatchFeature: + """ + Preprocess a atom vision input (images/video_chunk) into model-ready tensors. + + Args: + medias: List of MediaInput. + return_tensors: Desired output format ('pt', 'np', 'tf', or None). + + Returns: + BatchFeature containing 'pixel_values' and 'grid_thws' tensors. + """ + if not isinstance(medias, list): + medias = [medias] + if medias: + pixel_values = [] + for item in medias: + item = ensure_media_type(item) + resize_config = self.get_resize_config(item) + new_width, new_height, pad_width, pad_height = resize_config[ + 'new_width'], resize_config['new_height'], resize_config[ + 'pad_width'], resize_config['pad_height'] + if item['type'] == 'image': + image = item['image'] + image_np = self.resize_image(image, new_width, new_height, + pad_width, pad_height) + pixel_values.append(np.expand_dims(image_np, axis=0)) + elif item['type'] == 'video_chunk': + pixels = [] + for frame in item['video_chunk']: + frame_np = self.resize_image(frame, new_width, + new_height, pad_width, + pad_height) + pixels.append(frame_np) + pixel_values.append(np.stack(pixels, axis=0)) + else: + raise ValueError("Unsupported type: {}".format( + item['type'])) + normalized_pixel_values = [] + image_std_inv = 1.0 / np.array(self.media_proc_cfg['image_std']) + image_mean = np.array(self.media_proc_cfg['image_mean']) + for pixels in pixel_values: + pixels = normalize(pixels, image_mean, image_std_inv) + pixels_and_thw = navit_patchify( + pixels, + self.media_proc_cfg['patch_size'], + ) + normalized_pixel_values.append(pixels_and_thw) + + pixel_values = torch.cat([ + _to_tensor(pixel_value['pixel_values']) + for pixel_value in normalized_pixel_values + ]) + grid_thws = torch.cat([ + _to_tensor(pixel_value['grid_thw'], + dtype=torch.int64).unsqueeze(0) + for pixel_value in normalized_pixel_values + ]) + + data = { + 'pixel_values': pixel_values, + 'grid_thws': grid_thws, + } + + else: + data = {} + + return BatchFeature(data=data, tensor_type=return_tensors) + + def __repr__(self): + return f"KimiK25VisionProcessor(media_proc_cfg={self.media_proc_cfg})" + + def to_dict(self) -> Dict[str, Any]: + output = super().to_dict() + output["media_proc_cfg"] = self.media_proc_cfg + if "media_processor" in output: + del output["media_processor"] + return output + + @classmethod + def from_dict(cls, config_dict: Dict[str, Any], **kwargs): + config = config_dict.copy() + media_proc_cfg = config.pop("media_proc_cfg", {}) + return cls(media_proc_cfg=media_proc_cfg, **config, **kwargs) + + def to_json_string(self): + dictionary = self.to_dict() + for key, value in dictionary.items(): + if hasattr(value, 'tolist'): + dictionary[key] = value.tolist() + return json.dumps(dictionary, indent=2, sort_keys=True) + "\n" diff --git a/media_utils.py b/media_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8795e06f381700d6420798f82174e3f9647e9f89 --- /dev/null +++ b/media_utils.py @@ -0,0 +1,368 @@ +import base64 +import io +import math +import os +from datetime import datetime, timezone +from typing import List, Literal, Optional, TypedDict + +import numpy as np +from PIL import Image +from pydantic import BaseModel, Field + +try: + from mecord import VideoReader +except ImportError: + VideoReader = None + + +class VideoSpec(BaseModel): + media_type: str = Literal['video'] + height: int = Field(..., gt=0, description="video frame height") + width: int = Field(..., gt=0, description="video frame width") + num_frames: int = Field(..., gt=0, description="num frames") + fps: float = Field(..., gt=0, description="average fps") + + # optional, help to accelerate video reading + key_indices: list[int] = Field(None, description="key indices") + frame_time_info: dict = Field(None, description="frame time info") + + +class ImageInput(TypedDict): + type: Literal['image'] + image: Image.Image + + +class VideoChunkInput(TypedDict): + type: Literal['video_chunk'] + video_chunk: List[Image.Image] + prompt: Optional[str] = None + + +MediaInput = ImageInput | VideoChunkInput + + +def get_video_meta(video_src: bytes | str | os.PathLike, + accurate: bool = True) -> dict: + """Get the dimensions of a video.""" + if isinstance(video_src, os.PathLike): + video_src = str(video_src) + # if b64 string, decode to bytes + if isinstance(video_src, + str) and video_src.startswith('data:video/mp4;base64,'): + video_src = base64.b64decode(video_src.split(',')[1]) + video = VideoReader(video_src, auto_init=accurate, num_threads=1) + assert video.num_frames > 0, "Invalid video format." + assert video.original_width > 0 and video.original_height > 0, ( + "Invalid video format.") + assert video.avg_fps > 0, "Invalid video format." + return VideoSpec(media_type='video', + height=video.original_height, + width=video.original_width, + num_frames=video.num_frames, + fps=video.avg_fps, + key_indices=video.key_indices, + frame_time_info=video.frame_time_info) + + +def timestamp_as_str(timestamp: float, + timestamp_mode: str = "hh:mm:ss.fff") -> str: + """Convert a timestamp to a string in the format of HH:MM:SS.mmm.""" + if timestamp_mode == "hh:mm:ss.fff": + return (datetime.fromtimestamp(timestamp, + tz=timezone.utc).strftime("%H:%M:%S") + + f".{int((timestamp % 1) * 1000):03d}") + elif timestamp_mode == "mm:ss.fff": + return (datetime.fromtimestamp(timestamp, + tz=timezone.utc).strftime("%M:%S") + + f".{int((timestamp % 1) * 1000):03d}") + elif timestamp_mode == "mm:ss": + return datetime.fromtimestamp(timestamp, + tz=timezone.utc).strftime("%M:%S") + else: + raise ValueError(f"Invalid timestamp mode: {timestamp_mode}") + + +def navit_resize_image( + width: int, + height: int, + patch_size: int, + merge_kernel_size: int, + in_patch_limit: int, + patch_limit_on_one_side: int, + fixed_output_tokens: int | None, +): + # Apply the patch limits. + s1 = math.sqrt( + in_patch_limit / + (max(1.0, width // patch_size) * max(1.0, height // patch_size))) + s2 = patch_limit_on_one_side * patch_size / width + s3 = patch_limit_on_one_side * patch_size / height + scale = min(1.0, s1, s2, s3) + new_w, new_h = max(1, int(width * scale)), max(1, int(height * scale)) + new_w = min(new_w, patch_limit_on_one_side * patch_size) + new_h = min(new_h, patch_limit_on_one_side * patch_size) + + # Calculate the padding to make the height and width divisible by the merge kernel size and patch size. + factor = merge_kernel_size * patch_size + + pad_height = (factor - new_h % factor) % factor + pad_width = (factor - new_w % factor) % factor + + if fixed_output_tokens is not None: + num_tokens = fixed_output_tokens + else: + # Calculate new dimensions after padding and patching + token_height = (new_h + pad_height) // factor + token_width = (new_w + pad_width) // factor + + assert token_height * merge_kernel_size <= patch_limit_on_one_side, ( + f"token_height {token_height} * merge_kernel_size {merge_kernel_size} > patch_limit_on_one_side {patch_limit_on_one_side}" + ) + assert token_width * merge_kernel_size <= patch_limit_on_one_side, ( + f"token_width {token_width} * merge_kernel_size {merge_kernel_size} > patch_limit_on_one_side {patch_limit_on_one_side}" + ) + + num_tokens = token_height * token_width + return { + "num_tokens": num_tokens, + "new_width": new_w, + "new_height": new_h, + "pad_width": pad_width, + "pad_height": pad_height, + "sampled_nframes": 1, + } + + +def navit_resize_video( + width: int, + height: int, + nframes: int, + avg_fps: float, + sample_fps: float, + patch_size: int, + merge_kernel_size: int, + in_patch_limit_each_frame: int, + patch_limit_on_one_side: int, + in_patch_limit_total: int | None, + max_num_frames_each_video: int | None, + fixed_output_tokens_each_frame: int | None, +): + sample_fps = min(sample_fps, avg_fps) + # Calculate the number of frames to sample based on target FPS + sampled_nframes = max(round(nframes * sample_fps / avg_fps), 1) + if max_num_frames_each_video is not None: + sampled_nframes = min(sampled_nframes, max_num_frames_each_video) + + if in_patch_limit_total is not None: + in_patch_limit_each_frame = min( + round(in_patch_limit_total / sampled_nframes), + in_patch_limit_each_frame) + + ret = navit_resize_image( + width, + height, + patch_size, + merge_kernel_size, + in_patch_limit_each_frame, + patch_limit_on_one_side, + fixed_output_tokens_each_frame, + ) + ret["sampled_nframes"] = sampled_nframes + return ret + + +def real_sample_fps_and_max_num_frames( + type_name: Literal["video", "video_chunk"], + sample_fps: float, + max_num_frames_each_video: int | None, +) -> tuple[int, int | None]: + if type_name == "video": + return sample_fps, max_num_frames_each_video + elif type_name == "video_chunk": + max_num_frames_each_video = None + sample_fps = math.inf + return sample_fps, max_num_frames_each_video + else: + return math.inf, None + + +def _to_pil(data: str | bytes): + if isinstance(data, Image.Image): + + return data.convert("RGB") + elif isinstance(data, str): + if data.startswith("data:"): + raw_base64 = data.split(",")[1] + return Image.open(io.BytesIO( + base64.b64decode(raw_base64))).convert("RGB") + else: + return Image.open(data).convert("RGB") + elif isinstance(data, bytes): + return Image.open(io.BytesIO(data)).convert("RGB") + else: + raise ValueError(f"Unsupported data type: {type(data)}") + + +def ensure_media_type(media: MediaInput) -> MediaInput: + if media['type'] == 'image': + media['image'] = _to_pil(media['image']) + return media + elif media['type'] == 'video_chunk': + media['video_chunk'] = [ + _to_pil(frame) for frame in media['video_chunk'] + ] + return media + else: + raise ValueError(f"Unsupported media type: {media['type']}") + + +def image_to_np( + image: Image.Image, + resize_to: tuple[int, int] | None = None, + mode: str = "resize", + raise_error_for_ill_resize: bool = True, +) -> np.ndarray: + """Convert an image to a numpy array. + + Args: + content: The image to convert. + resize_to: The size to resize the image to. + mode: The mode to resize the image to. + raise_error_for_ill_resize: Whether to raise an error for ill-sized resize. + + Returns: + A numpy array. + """ + assert isinstance(image, Image.Image), "image must be a PIL Image" + if resize_to is not None: + if mode == "resize": + image = image.resize(resize_to, resample=Image.Resampling.BICUBIC) + + elif mode == "rescale_and_pad_to_center": + scale = min(resize_to[0] / image.width, + resize_to[1] / image.height, 1.0) + new_width = round(image.width * scale) + new_height = round(image.height * scale) + if new_width == 0 or new_height == 0: + if raise_error_for_ill_resize: + raise ValueError( + f"Invalid resize to: {resize_to}, from image size: {image.size}" + ) + else: + return np.zeros((resize_to[1], resize_to[0], 3), + dtype=np.uint8) + + image = image.resize((new_width, new_height), + resample=Image.Resampling.BICUBIC) + padding_left = (resize_to[0] - new_width) // 2 + padding_right = resize_to[0] - new_width - padding_left + padding_top = (resize_to[1] - new_height) // 2 + padding_bottom = resize_to[1] - new_height - padding_top + image = np.asarray(image) + image = np.pad( + image, + ((padding_top, padding_bottom), (padding_left, padding_right), + (0, 0)), + mode="constant", + constant_values=0, + ) + assert image.shape == (resize_to[1], resize_to[0], 3) + + elif mode == "rescale_and_pad_to_rightbottom": + scale = min(resize_to[0] / image.width, + resize_to[1] / image.height, 1.0) + new_width = round(image.width * scale) + new_height = round(image.height * scale) + if new_width == 0 or new_height == 0: + if raise_error_for_ill_resize: + raise ValueError( + f"Invalid resize to: {resize_to}, from image size: {image.size}" + ) + else: + return np.zeros((resize_to[1], resize_to[0], 3), + dtype=np.uint8) + + image = image.resize((new_width, new_height), + resample=Image.Resampling.BICUBIC) + padding_right = resize_to[0] - new_width + padding_bottom = resize_to[1] - new_height + image = np.asarray(image) + image = np.pad( + image, + ((0, padding_bottom), (0, padding_right), (0, 0)), + mode="constant", + constant_values=0, + ) + assert image.shape == (resize_to[1], resize_to[0], 3) + + else: + raise ValueError(f"Invalid mode: {mode}") + + if isinstance(image, Image.Image): + return np.asarray(image) + else: + return image + + +def navit_patchify(pixel_values: np.ndarray, + patch_size: int) -> dict[str, np.ndarray]: + """Reshape the pixel values to a navit shape. + + Args: + pixel_values: np.ndarray, shape (t, h, w, c) + patch_size: int + + Returns: + dict[str, np.ndarray] + - patches: np.ndarray, shape (t * h//patch_size * w//patch_size, c, patch_size, patch_size) + - grid_thw: np.ndarray, (t, h//patch_size, w//patch_size) + """ + T, H, W, C = pixel_values.shape + assert C == 3, "pixel_values must have 3 channels" + + patches = pixel_values.reshape(T, H // patch_size, patch_size, + W // patch_size, patch_size, C) + # (T, H//patch_size, W//patch_size, C, patch_size, patch_size) + patches = patches.transpose(0, 1, 3, 5, 2, 4) + patches = patches.reshape(-1, C, patch_size, patch_size) + grid_thw = np.array([T, H // patch_size, W // patch_size]) + return {"pixel_values": patches, "grid_thw": grid_thw} + + +def normalize(x: np.ndarray, + mean, + std_inv, + pixels_dtype: np.dtype = np.float32) -> np.ndarray: + """Normalize the image. + + Args: + x: The image to normalize. The shape is (..., 3). The dtype is uint8. The range is [0, 255]. + mean: The mean of the image. + std_inv: The inverse of the std of the image. + pixels_dtype: The dtype of the image. + Returns: + The normalized image. The shape is (..., 3). The dtype is determined by the pixels_dtype. + """ + x = (x / 255.0).astype(pixels_dtype) + x -= mean + x *= std_inv + return x + + +def _to_tensor(data, **kwargs): + import torch + + if isinstance(data, np.ndarray): + return torch.from_numpy(data).to(**kwargs) + elif isinstance(data, torch.Tensor): + return data.to(**kwargs) + elif isinstance(data, list): + return [_to_tensor(item, **kwargs) for item in data] + elif isinstance(data, tuple): + return tuple(_to_tensor(item, **kwargs) for item in data) + elif isinstance(data, dict): + return {k: _to_tensor(v, **kwargs) for k, v in data.items()} + elif data is None: + return None + else: + raise ValueError(f"Unsupported data type: {type(data)}") diff --git a/model-00001-of-000064.safetensors b/model-00001-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..809b7aaebbd49d29c3f749c978a4d5d41af7e822 --- /dev/null +++ b/model-00001-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:18daf53a15070b9c70d8bb63420dbd39764af9118af67982eeb60749f5453233 +size 995001888 diff --git a/model-00002-of-000064.safetensors b/model-00002-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..3cb514dfe6cd50948b2b5bdba1d109255f2b3eb1 --- /dev/null +++ b/model-00002-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e017c948926333558df1a9637ff052c663378a70afbc1bb5b20528b8b5a501aa +size 9809047464 diff --git a/model-00003-of-000064.safetensors b/model-00003-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..f8d6214c1b62a1c61becbeeff5c5be7eb604e075 --- /dev/null +++ b/model-00003-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c74880b04dc208397d7182c4692ca67644ea732be216ac4b23241db389f86886 +size 9809047464 diff --git a/model-00004-of-000064.safetensors b/model-00004-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..1011088c8df7ed9d43275fc8b9a5f979a0a82b4f --- /dev/null +++ b/model-00004-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:17faa37dbf12b17eb3cb94bc7d9b85db9f7d68411bdecec84224425b25b22fca +size 9809047464 diff --git a/model-00005-of-000064.safetensors b/model-00005-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..404d054795db6680acb4e055a8a24dd619334c6a --- /dev/null +++ b/model-00005-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a97ff66ec88a9e02a129cb3dc75b2c3cb4fc4ccf698777042253504be372b52f +size 9809047464 diff --git a/model-00006-of-000064.safetensors b/model-00006-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..c1dad5c90fb9fec8bdaea35394abd4dc211230aa --- /dev/null +++ b/model-00006-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:da5cfbfe78d1c740c3cbe8ba49d84712d4328be72d3389dad54361d480cd3148 +size 9809047464 diff --git a/model-00007-of-000064.safetensors b/model-00007-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..2db6670c9ef0c897fa833ed1e26662cea67d5f70 --- /dev/null +++ b/model-00007-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:65b0cd81a5c8320fdfbba91f3530bf35da9befe7a776da1f4d5485083b527d8a +size 9809047464 diff --git a/model-00008-of-000064.safetensors b/model-00008-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..6752ac30894c1cca842d03f404d5d83b56516da0 --- /dev/null +++ b/model-00008-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:da91fa6a5e61a3dc36a2a9e46e0be066ce67ae78a1f2dcdce04511c38e30fb5d +size 9809047464 diff --git a/model-00009-of-000064.safetensors b/model-00009-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..ce1aa871f207e66cfa13c221b78ce2e1a50fbfb7 --- /dev/null +++ b/model-00009-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e4f08c59cee6cc26a92cfb5f36952511b171d5db527d1d78f84f7a081ab78599 +size 9809047464 diff --git a/model-00010-of-000064.safetensors b/model-00010-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..46d5ba31b98f2549f1fa4486016d39d08476f7a5 --- /dev/null +++ b/model-00010-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7bdae095b9c721830784435564dada40a63c2c4b434cba4d81dbc0b1173e3d75 +size 9809047464 diff --git a/model-00011-of-000064.safetensors b/model-00011-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..a2786895e4eb1302fc6d057e033e0e31ba4a7a22 --- /dev/null +++ b/model-00011-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:899dbf01f498fa71c4e97833caac73c7fbb3e197a237f2f7d9fea9c5262afdc3 +size 9809050936 diff --git a/model-00012-of-000064.safetensors b/model-00012-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..5915e61887affe2b0812b72e781c3b242cff7314 --- /dev/null +++ b/model-00012-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:72be28902ba87b0a1ac24ebb3b36ee89850cc7f13fbbb51e8b1dba743c7eb171 +size 9809050936 diff --git a/model-00013-of-000064.safetensors b/model-00013-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..a75b0088cf21c85ba43e2b7b24d1f5de53566038 --- /dev/null +++ b/model-00013-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f375d01e1b1230cb0bbfc3c6ce6a81c26a578b633dc43e24956610c84205d688 +size 9809050936 diff --git a/model-00014-of-000064.safetensors b/model-00014-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..6cd0b4c100fb1c2d46e518c346bf5415596cc2a1 --- /dev/null +++ b/model-00014-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b341417f09964cf477c60fd431ecd6e454442e973d41a17038004340ce8ad83e +size 9809050936 diff --git a/model-00015-of-000064.safetensors b/model-00015-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..8f94dfecdbf91548d83689e02df6d20c8000a95b --- /dev/null +++ b/model-00015-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9534ee2411e8c05451e19b7234ceddcd33271795b782d910baa9ebb96738d2f9 +size 9809050936 diff --git a/model-00016-of-000064.safetensors b/model-00016-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..c624966151e743fcbe6230bee650897320d0a57f --- /dev/null +++ b/model-00016-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:46cd89549baf6099c1629a1a023b0a6886d5d115f9afc97ffb68390339967eb7 +size 9809050936 diff --git a/model-00017-of-000064.safetensors b/model-00017-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..827b47317aac25d14236b2276e626f5b21f5cce9 --- /dev/null +++ b/model-00017-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2b653d852e4d30f8eceb0800ae1a5b15c02fd203fae04c097aebb8fc99a0ef03 +size 9809050936 diff --git a/model-00018-of-000064.safetensors b/model-00018-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..6a3f7e0cee9b508679e95ee4e11dc540f7232c29 --- /dev/null +++ b/model-00018-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cd82fedc4da324d1be93db896d0b30ccc02ffd4d267bd1b64875c5cc0d7cee22 +size 9809050936 diff --git a/model-00019-of-000064.safetensors b/model-00019-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..228b260293d6ed64bc4e0e349ae6b39c171b3938 --- /dev/null +++ b/model-00019-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:835c2cce752c8f7030d04fdd11a27311efc489b1558dee9d646d135d3cd82044 +size 9809050936 diff --git a/model-00020-of-000064.safetensors b/model-00020-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..bbe0c6cbb1f40d0b8818c7ced86dbdf6581e8dfa --- /dev/null +++ b/model-00020-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:26a3ec988f47120f07938afbf7ce3a27dbeeaf955e816ccf202ebf3751833438 +size 9809050936 diff --git a/model-00021-of-000064.safetensors b/model-00021-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..115375f04ac0408bf6eaefb5a8b52a3ab9a3a8c2 --- /dev/null +++ b/model-00021-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:02376f1b7e16343f0e7ea4d847c85e33da30295499c25145fbd4998cb44fdcff +size 9809050936 diff --git a/model-00022-of-000064.safetensors b/model-00022-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..57e90f76bfb569e5c9a151b043f6d320fe6873a4 --- /dev/null +++ b/model-00022-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a177b349938d277b2ff17d4593981343c7a66a6445b964fd596551dc325dd889 +size 9809050936 diff --git a/model-00023-of-000064.safetensors b/model-00023-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..bee4b03197546f787d755da2a59d428918acc92e --- /dev/null +++ b/model-00023-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ecbbd2355a4b74d425a1ab847de07d603a0f23c1eb5615563e72f224335022d7 +size 9809050936 diff --git a/model-00024-of-000064.safetensors b/model-00024-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..a47b0f110555eed763a22492f7a9edae816c4b55 --- /dev/null +++ b/model-00024-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:505f7178c7e3d089aa51d3eb83b2133e379ae75e25b00ebc80793c4affbd90e3 +size 9809050936 diff --git a/model-00025-of-000064.safetensors b/model-00025-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..325678d4d56b498c3d2f6316ea85257e6993c9a1 --- /dev/null +++ b/model-00025-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9240b5987ffbe8b47f0668db0b7608916e1648bc9a3a0b55f6079c222e430a83 +size 9809050936 diff --git a/model-00026-of-000064.safetensors b/model-00026-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..4942622fa799d760299eee6ceec658401802211b --- /dev/null +++ b/model-00026-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4886aa2bf3e83fbb83adee7fde183b9476820dafcbaade99a87b95c8946e5b78 +size 9809050936 diff --git a/model-00027-of-000064.safetensors b/model-00027-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..6a64f025c652d37165c2fa87d32c052b58d13559 --- /dev/null +++ b/model-00027-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:06e78ba422ca3d7188d701d6cbfef9ca79a94e56c572894ac63666cc90336a0c +size 9809050936 diff --git a/model-00028-of-000064.safetensors b/model-00028-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..ea8608f970a2ee3e8d3d770e02865437461886cc --- /dev/null +++ b/model-00028-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4235b1e3d587a2ebbdcf111eb231c1d66a891ef0c8e6bc91214e4ebccbfa499f +size 9809050936 diff --git a/model-00029-of-000064.safetensors b/model-00029-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..0ed25f79fabb2661d069eae5d98915870d3f5498 --- /dev/null +++ b/model-00029-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2e37a630cd0f516cf82442de3db77f1f4ec7f0fa7b37a429c51dd17c26fd6b02 +size 9809050936 diff --git a/model-00030-of-000064.safetensors b/model-00030-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..9274a4b79445f7d081ea6cba1536f578b9e3a9f0 --- /dev/null +++ b/model-00030-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:04da30f98f4c57c6a821090be82d0d8e3af92ccb3edefb7cabed893d96ee0a04 +size 9809050936 diff --git a/model-00031-of-000064.safetensors b/model-00031-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..8ee9f60b58c16e569404fa62ec1c4f0c19a3d640 --- /dev/null +++ b/model-00031-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fdf3def720479758dfb38f53b6cf97d4d326fd39f215b2d1c65699423d3af2ad +size 9809050936 diff --git a/model-00032-of-000064.safetensors b/model-00032-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..047ec2922d68e8f8f4774a411d6af9559174e0b9 --- /dev/null +++ b/model-00032-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:52534a0a523aaf64c061297aa1efe3b6cb5e42a0af422f14933c159c92ba9e49 +size 9809050936 diff --git a/model-00033-of-000064.safetensors b/model-00033-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..b25d5a7ac303c0a6e6c45c864209c06879f22b56 --- /dev/null +++ b/model-00033-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3b8438e1e9a6573f403c03defa00aa0649f716b5da098beb422827834e889ea8 +size 9809050936 diff --git a/model-00034-of-000064.safetensors b/model-00034-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..1e2c822f5a6de011fafb4eb2f0a5e1d5e53101ec --- /dev/null +++ b/model-00034-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bf6a7b9cf582a34c663e945ead490b13f3f6a07a00ff247c29afc4bf908960e3 +size 9809050936 diff --git a/model-00035-of-000064.safetensors b/model-00035-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..41fdf8c5beb0e17a412889a50f7ceca2a397a408 --- /dev/null +++ b/model-00035-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fe6cd19a160352d14d6f28bb66dd7d9c06f1e9350025645def52fa979b77e004 +size 9809050936 diff --git a/model-00036-of-000064.safetensors b/model-00036-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..e562b19af5b5ec6c20054f405918c06bb74e7457 --- /dev/null +++ b/model-00036-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c5487f8977f5ae403c45f36c34add86a8616da4bb8fa4f7dbccc9a0c156dcc1d +size 9809050936 diff --git a/model-00037-of-000064.safetensors b/model-00037-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..f591e2387219e3af33b819fde1f229a0c4c4a317 --- /dev/null +++ b/model-00037-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4236900d679c2a9c134e71e080b905fa877a53648ee7da7d36d3bf0b1f5586bb +size 9809050936 diff --git a/model-00038-of-000064.safetensors b/model-00038-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..cf23861d6893097ebe2fcb2ae58f0060fde10205 --- /dev/null +++ b/model-00038-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b1c15627438fcdd2c905aa78a682f51a26fbdc50403bfa74601be6ee1550da3b +size 9809050936 diff --git a/model-00039-of-000064.safetensors b/model-00039-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..34ebc9b80a48cf31772c4cd5e912b09a01d27050 --- /dev/null +++ b/model-00039-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:67cbaa9ecb8dc2dfd53c396927439bfd6a8a7503a49997615004fe5df0138e20 +size 9809050936 diff --git a/model-00040-of-000064.safetensors b/model-00040-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..c861c2f52c3df7298fd466ddf34054b2fa1b20d2 --- /dev/null +++ b/model-00040-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:76b1c63b64a234b2631566f983f5775393ff3d8cbb6d8828fe9d168a039873ea +size 9809050936 diff --git a/model-00041-of-000064.safetensors b/model-00041-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..bb000aced184f94e91d85fe7c1bddee4d1fe4d83 --- /dev/null +++ b/model-00041-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b7e27bd85a6de53b3098a5ebc83e47dae99ce0d207dca130681e2ca1f95ac649 +size 9809050936 diff --git a/model-00042-of-000064.safetensors b/model-00042-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..6bfab17dc0d79ee92117e1babee9c727597938aa --- /dev/null +++ b/model-00042-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:67a8e121c665e0ad9fe9e72406d4560cc513b6a3df85ab05b8063e26cedc3f0b +size 9809050936 diff --git a/model-00043-of-000064.safetensors b/model-00043-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..42dd6ef5f8398964ca9960e06b22e5006f8f7845 --- /dev/null +++ b/model-00043-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:743f85f3f9d57e10432ab2961728385f5cb8ea8035e14b4fde83086b062b1afb +size 9809050936 diff --git a/model-00044-of-000064.safetensors b/model-00044-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..2e7141c7d6450f7b5472b160f43ad5f48dfde56d --- /dev/null +++ b/model-00044-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2ce242860d1d06de3c8d9077817e83e48841daf9d85cacb2ec19fbff04c69161 +size 9809050936 diff --git a/model-00045-of-000064.safetensors b/model-00045-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..0bad12d4f8aa53eeb3bb0719e21db850a40f54e2 --- /dev/null +++ b/model-00045-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:59b8afda0898ca3dd0ea382e941d05d1624061716f83b296f456a6e27db43f49 +size 9809050936 diff --git a/model-00046-of-000064.safetensors b/model-00046-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..b5902e577cc448be6561266539759cda5b0f28f1 --- /dev/null +++ b/model-00046-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5e6118090a8738a46430bae56464855ea0adf2ab2f4752b63c96a53c88f3b74c +size 9809050936 diff --git a/model-00047-of-000064.safetensors b/model-00047-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..0816b5795d80d08f7cff45ec8de82fe94773237e --- /dev/null +++ b/model-00047-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:953fb907ef1c49724b402c0f420ce3007eac3f20d4b2ecf286012b4cde671cb8 +size 9809050936 diff --git a/model-00048-of-000064.safetensors b/model-00048-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..4276e2798ae51fc60b18832fd542f8357de34422 --- /dev/null +++ b/model-00048-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:587db1e460105d0911601db58c115cb8b6d136b2b327426127dc8a45ff39c5e8 +size 9809050936 diff --git a/model-00049-of-000064.safetensors b/model-00049-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..194b43acc48c3e575e7ac2b912b62bf353ffac3d --- /dev/null +++ b/model-00049-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a93c151606f2853f398eecd708da5a2c8975600bf319f6b7f1bade66f83a8ca9 +size 9809050936 diff --git a/model-00050-of-000064.safetensors b/model-00050-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..47c04637d871aab245be7c568b3bb7b1ddf26380 --- /dev/null +++ b/model-00050-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6a5be161b2f3af9e1f299588d4ef898a174504324aa163398c0a510d8cf735eb +size 9809050936 diff --git a/model-00051-of-000064.safetensors b/model-00051-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..6bbed9a8105e8d921048e36c9a9be117e2b4f031 --- /dev/null +++ b/model-00051-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d042fbe87dc47857f897b7491b2a965207be73d7477e5a1d8c46d5f25441615e +size 9809050936 diff --git a/model-00052-of-000064.safetensors b/model-00052-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..f94261034a40910aa8509156752a7a63b2cda871 --- /dev/null +++ b/model-00052-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eaddbb07f16c5567ca9b94ef8585124188197877ce7c11b3c9818431da884adc +size 9809050936 diff --git a/model-00053-of-000064.safetensors b/model-00053-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..e222b2504e87d2c954ae1300fcbbaf548c3a14fb --- /dev/null +++ b/model-00053-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:87a86769c0b7f8bc5f260ddedf35f85c4fc73db5524d025e9241adcb633bba15 +size 9809050936 diff --git a/model-00054-of-000064.safetensors b/model-00054-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..28acbc75350b7cdf3b6a90d31956d229657eadc7 --- /dev/null +++ b/model-00054-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:803a24c512c5e2b0cd336323720be011df58036e176d6d63f7f1976ec8e533ab +size 9809050936 diff --git a/model-00055-of-000064.safetensors b/model-00055-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..4be9ee28df052abdffabf7268fb5861c34dc0588 --- /dev/null +++ b/model-00055-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fcb683f1c9170fc6a661c569dc18163934e2629c0d25f278e4bc73f2ad09bfbe +size 9809050936 diff --git a/model-00056-of-000064.safetensors b/model-00056-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..7a9141aded7f744c44dbe3931153659d2b5c96f3 --- /dev/null +++ b/model-00056-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3ef1f8d5bac88618ad5104a0ac30c268eac98a2add384c0bd31ec82cf8c3d5bd +size 9809050936 diff --git a/model-00057-of-000064.safetensors b/model-00057-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..7de702056635f305af10cc13f89fca48343b4535 --- /dev/null +++ b/model-00057-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:db51daa87c01b891ab72cf41acc01be732a5f8108afa32c7a50e8e1bba57e263 +size 9809050936 diff --git a/model-00058-of-000064.safetensors b/model-00058-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..d047d9108a185ef4467957e3d8c4a11a69120e5b --- /dev/null +++ b/model-00058-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:07998dcd30b032da866526de21463c7d4b54c16818ccac2d0faf9609b8d3ec8d +size 9809050936 diff --git a/model-00059-of-000064.safetensors b/model-00059-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..af5c1fdbe2a9c5c4896cf6018121c691206c0df2 --- /dev/null +++ b/model-00059-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f7c1b87ab8a9c0df0e85acb6af5c5b7096ac434aaf2df17ca05b0b12ccd5a8ea +size 9809050936 diff --git a/model-00060-of-000064.safetensors b/model-00060-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..2f1cfdbcf528cfb3d48bc59d96ff679dc3fdab9b --- /dev/null +++ b/model-00060-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d328618ace908cc752a8197591157a6ab060205f95b222e11bb9c57d8d906802 +size 9809050936 diff --git a/model-00061-of-000064.safetensors b/model-00061-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..cf54aa8d346858addc7352e8e8152de7d1243398 --- /dev/null +++ b/model-00061-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b10e67a118315d38fa740093f4c6f82f468eb7f2881e08639c74230927dfdc51 +size 9809050936 diff --git a/model-00062-of-000064.safetensors b/model-00062-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..199c4a87f31122cbb7644df59cf83d0e9f0ce5fe --- /dev/null +++ b/model-00062-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:60f6e14a3f89fb81a1565de92417606379ed6f4aeaafff2ccba679a53df8c77f +size 4697635160 diff --git a/model-00063-of-000064.safetensors b/model-00063-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..ce0fd2cbb8a4d8666f5d7680a3fde21d9edcbcb6 --- /dev/null +++ b/model-00063-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0eb212be70e29ade1a3de5fdc24bff35aaee507dfba745ffb74febf2b05395e6 +size 108556344 diff --git a/model-00064-of-000064.safetensors b/model-00064-of-000064.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..9645e6a5ecab068cf768fa0e52999f1e6a4f9799 --- /dev/null +++ b/model-00064-of-000064.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7a8bc2d71c92fce81cbda9da8a33db470eb1f5eb899e87c03c0973cba34ddf8f +size 833769904 diff --git a/model.safetensors.index.json b/model.safetensors.index.json new file mode 100644 index 0000000000000000000000000000000000000000..9a11548c9004aa9b3c0e8fa6b03f2a187c3400d1 --- /dev/null +++ b/model.safetensors.index.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bdba19b127c4d1dc57dc3b6f3366c10739c7e7f13baf3f5424b556469a4dbc1b +size 23597438 diff --git a/modeling_deepseek.py b/modeling_deepseek.py new file mode 100644 index 0000000000000000000000000000000000000000..707925a7f440d22c1fde9ea13676b7ee924d6e21 --- /dev/null +++ b/modeling_deepseek.py @@ -0,0 +1,1808 @@ +# coding=utf-8 +# Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch DeepSeek model.""" +import math +import warnings +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_attn_mask_utils import \ + _prepare_4d_causal_attention_mask +from transformers.modeling_outputs import (BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import (ALL_LAYERNORM_LAYERS, + is_torch_greater_or_equal_than_1_13) +from transformers.utils import (add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, logging, + replace_return_docstrings) +from transformers.utils.import_utils import is_torch_fx_available + +from .configuration_deepseek import DeepseekV3Config + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import pad_input # noqa + from flash_attn.bert_padding import index_first_axis, unpad_input + +# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. +# It means that the function will not be traced through and simply appear as a node in the graph. +if is_torch_fx_available(): + if not is_torch_greater_or_equal_than_1_13: + import torch.fx + + _prepare_4d_causal_attention_mask = torch.fx.wrap( + _prepare_4d_causal_attention_mask) + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "DeepseekV3Config" + + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +# code modified from transformers 4.48.3 to amend breaks in newer transformers versions +def get_usable_length(past_key_value, + new_seq_length: int, + layer_idx: Optional[int] = 0) -> int: + max_length = past_key_value.get_max_cache_shape() + previous_seq_length = past_key_value.get_seq_length(layer_idx) + if max_length is not None and max_length > 0 and previous_seq_length + new_seq_length > max_length: + return max_length - new_seq_length + return previous_seq_length + + +class DeepseekV3RMSNorm(nn.Module): + + def __init__(self, hidden_size, eps=1e-6): + """ + DeepseekV3RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +ALL_LAYERNORM_LAYERS.append(DeepseekV3RMSNorm) + + +class DeepseekV3RotaryEmbedding(nn.Module): + + def __init__(self, + dim, + max_position_embeddings=2048, + base=10000, + device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base**( + torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.get_default_dtype(), + ) + self.max_seq_len_cached = None + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, + device=device, + dtype=self.inv_freq.dtype) + + freqs = torch.outer(t, self.inv_freq.to(t.device)) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", + emb.cos().to(dtype), + persistent=False) + self.register_buffer("sin_cached", + emb.sin().to(dtype), + persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, + device=x.device, + dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->DeepseekV3 +class DeepseekV3LinearScalingRotaryEmbedding(DeepseekV3RotaryEmbedding): + """DeepseekV3RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, + device=device, + dtype=self.inv_freq.dtype) + t = t / self.scaling_factor + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", + emb.cos().to(dtype), + persistent=False) + self.register_buffer("sin_cached", + emb.sin().to(dtype), + persistent=False) + + +# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->DeepseekV3 +class DeepseekV3DynamicNTKScalingRotaryEmbedding(DeepseekV3RotaryEmbedding): + """DeepseekV3RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ((self.scaling_factor * seq_len / + self.max_position_embeddings) - + (self.scaling_factor - 1))**(self.dim / + (self.dim - 2)) + inv_freq = 1.0 / (base**( + torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(self.max_seq_len_cached, + device=device, + dtype=self.inv_freq.dtype) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", + emb.cos().to(dtype), + persistent=False) + self.register_buffer("sin_cached", + emb.sin().to(dtype), + persistent=False) + + +# Inverse dim formula to find dim based on number of rotations +def yarn_find_correction_dim(num_rotations, + dim, + base=10000, + max_position_embeddings=2048): + return (dim * math.log(max_position_embeddings / + (num_rotations * 2 * math.pi))) / (2 * + math.log(base)) + + +# Find dim range bounds based on rotations +def yarn_find_correction_range(low_rot, + high_rot, + dim, + base=10000, + max_position_embeddings=2048): + low = math.floor( + yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil( + yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def yarn_get_mscale(scale=1, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +def yarn_linear_ramp_mask(min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +class DeepseekV3YarnRotaryEmbedding(DeepseekV3RotaryEmbedding): + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + original_max_position_embeddings=4096, + beta_fast=32, + beta_slow=1, + mscale=1, + mscale_all_dim=0, + ): + self.scaling_factor = scaling_factor + self.original_max_position_embeddings = original_max_position_embeddings + self.beta_fast = beta_fast + self.beta_slow = beta_slow + self.mscale = mscale + self.mscale_all_dim = mscale_all_dim + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + dim = self.dim + + freq_extra = 1.0 / (self.base**( + torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)) + freq_inter = 1.0 / (self.scaling_factor * self.base**( + torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)) + + low, high = yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + dim, + self.base, + self.original_max_position_embeddings, + ) + inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to( + device=device, dtype=torch.float32) + inv_freq = freq_inter * (1 - + inv_freq_mask) + freq_extra * inv_freq_mask + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(seq_len, device=device, dtype=torch.float32) + + freqs = torch.outer(t, inv_freq) + + _mscale = float( + yarn_get_mscale(self.scaling_factor, self.mscale) / + yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", (emb.cos() * _mscale).to(dtype), + persistent=False) + self.register_buffer("sin_cached", (emb.sin() * _mscale).to(dtype), + persistent=False) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + b, h, s, d = q.shape + q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + b, h, s, d = k.shape + k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class DeepseekV3MLP(nn.Module): + + def __init__(self, config, hidden_size=None, intermediate_size=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size if hidden_size is None else hidden_size + self.intermediate_size = (config.intermediate_size if intermediate_size + is None else intermediate_size) + + self.gate_proj = nn.Linear(self.hidden_size, + self.intermediate_size, + bias=False) + self.up_proj = nn.Linear(self.hidden_size, + self.intermediate_size, + bias=False) + self.down_proj = nn.Linear(self.intermediate_size, + self.hidden_size, + bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj( + self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class MoEGate(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.n_routed_experts = config.n_routed_experts + self.routed_scaling_factor = config.routed_scaling_factor + self.scoring_func = config.scoring_func + self.seq_aux = config.seq_aux + self.topk_method = config.topk_method + self.n_group = config.n_group + self.topk_group = config.topk_group + + # topk selection algorithm + self.norm_topk_prob = config.norm_topk_prob + self.gating_dim = config.hidden_size + self.weight = nn.Parameter( + torch.empty((self.n_routed_experts, self.gating_dim))) + if self.topk_method == "noaux_tc": + self.e_score_correction_bias = nn.Parameter( + torch.empty((self.n_routed_experts))) + self.reset_parameters() + + def reset_parameters(self) -> None: + import torch.nn.init as init + + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + + def forward(self, hidden_states): + bsz, seq_len, h = hidden_states.shape + ### compute gating score + hidden_states = hidden_states.view(-1, h) + logits = F.linear(hidden_states.type(torch.float32), + self.weight.type(torch.float32), None) + if self.scoring_func == "sigmoid": + scores = logits.sigmoid() + else: + raise NotImplementedError( + f"insupportable scoring function for MoE gating: {self.scoring_func}" + ) + + ### select top-k experts + if self.topk_method == "noaux_tc": + assert not self.training + scores_for_choice = scores.view( + bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0) + group_scores = (scores_for_choice.view( + bsz * seq_len, self.n_group, + -1).topk(2, dim=-1)[0].sum(dim=-1)) # [n, n_group] + group_idx = torch.topk(group_scores, + k=self.topk_group, + dim=-1, + sorted=False)[1] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = (group_mask.unsqueeze(-1).expand( + bsz * seq_len, self.n_group, + self.n_routed_experts // self.n_group).reshape( + bsz * seq_len, -1)) # [n, e] + tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), + 0.0) # [n, e] + _, topk_idx = torch.topk(tmp_scores, + k=self.top_k, + dim=-1, + sorted=False) + topk_weight = scores.gather(1, topk_idx) + else: + raise NotImplementedError( + f"insupportable TopK function for MoE gating: {self.topk_method}" + ) + + ### norm gate to sum 1 + if self.top_k > 1 and self.norm_topk_prob: + denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 + topk_weight = topk_weight / denominator + topk_weight = topk_weight * self.routed_scaling_factor # must multiply the scaling factor + + return topk_idx, topk_weight + + +class DeepseekV3MoE(nn.Module): + """ + A mixed expert module containing shared experts. + """ + + def __init__(self, config): + super().__init__() + self.config = config + self.num_experts_per_tok = config.num_experts_per_tok + + if hasattr(config, "ep_size") and config.ep_size > 1: + assert config.ep_size == dist.get_world_size() + self.ep_size = config.ep_size + self.experts_per_rank = config.n_routed_experts // config.ep_size + self.ep_rank = dist.get_rank() + self.experts = nn.ModuleList([ + (DeepseekV3MLP(config, + intermediate_size=config.moe_intermediate_size) + if i >= self.ep_rank * self.experts_per_rank + and i < (self.ep_rank + 1) * self.experts_per_rank else None) + for i in range(config.n_routed_experts) + ]) + else: + self.ep_size = 1 + self.experts_per_rank = config.n_routed_experts + self.ep_rank = 0 + self.experts = nn.ModuleList([ + DeepseekV3MLP(config, + intermediate_size=config.moe_intermediate_size) + for i in range(config.n_routed_experts) + ]) + self.gate = MoEGate(config) + if config.n_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = DeepseekV3MLP( + config=config, intermediate_size=intermediate_size) + + def forward(self, hidden_states): + identity = hidden_states + orig_shape = hidden_states.shape + topk_idx, topk_weight = self.gate(hidden_states) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + flat_topk_idx = topk_idx.view(-1) + if not self.training: + y = self.moe_infer(hidden_states, topk_idx, + topk_weight).view(*orig_shape) + if self.config.n_shared_experts is not None: + y = y + self.shared_experts(identity) + return y + + @torch.no_grad() + def moe_infer(self, x, topk_ids, topk_weight): + cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) + cnts.scatter_(1, topk_ids, 1) + tokens_per_expert = cnts.sum(dim=0) + idxs = topk_ids.view(-1).argsort() + sorted_tokens = x[idxs // topk_ids.shape[1]] + sorted_tokens_shape = sorted_tokens.shape + if self.ep_size > 1: + tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, + -1).sum(dim=1) + tokens_per_expert_group = tokens_per_expert.new_empty( + tokens_per_expert.shape[0]) + dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert) + output_splits = (tokens_per_expert_group.view( + self.ep_size, -1).sum(1).cpu().numpy().tolist()) + gathered_tokens = sorted_tokens.new_empty( + tokens_per_expert_group.sum(dim=0).cpu().item(), + sorted_tokens.shape[1]) + input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist() + dist.all_to_all( + list(gathered_tokens.split(output_splits)), + list(sorted_tokens.split(input_split_sizes)), + ) + tokens_per_expert_post_gather = tokens_per_expert_group.view( + self.ep_size, self.experts_per_rank).sum(dim=0) + gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0], ), + dtype=np.int32) + s = 0 + for i, k in enumerate(tokens_per_expert_group.cpu().numpy()): + gatherd_idxs[s:s + k] = i % self.experts_per_rank + s += k + gatherd_idxs = gatherd_idxs.argsort() + sorted_tokens = gathered_tokens[gatherd_idxs] + tokens_per_expert = tokens_per_expert_post_gather + tokens_per_expert = tokens_per_expert.cpu().numpy() + + outputs = [] + start_idx = 0 + for i, num_tokens in enumerate(tokens_per_expert): + end_idx = start_idx + num_tokens + if num_tokens == 0: + continue + expert = self.experts[i + self.ep_rank * self.experts_per_rank] + tokens_for_this_expert = sorted_tokens[start_idx:end_idx] + expert_out = expert(tokens_for_this_expert) + outputs.append(expert_out) + start_idx = end_idx + + outs = torch.cat(outputs, + dim=0) if len(outputs) else sorted_tokens.new_empty(0) + if self.ep_size > 1: + new_x = torch.empty_like(outs) + new_x[gatherd_idxs] = outs + gathered_tokens = new_x.new_empty(*sorted_tokens_shape) + dist.all_to_all( + list(gathered_tokens.split(input_split_sizes)), + list(new_x.split(output_splits)), + ) + outs = gathered_tokens + + new_x = torch.empty_like(outs) + new_x[idxs] = outs + final_out = (new_x.view( + *topk_ids.shape, -1).type(topk_weight.dtype).mul_( + topk_weight.unsqueeze(dim=-1)).sum(dim=1).type(new_x.dtype)) + return final_out + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, + None, :, :].expand(batch, + num_key_value_heads, + n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, + head_dim) + + +# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV3 +class DeepseekV3Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, + config: DeepseekV3Config, + layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class.") + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + + self.is_causal = True + + if self.q_lora_rank is None: + self.q_proj = nn.Linear(self.hidden_size, + self.num_heads * self.q_head_dim, + bias=False) + else: + self.q_a_proj = nn.Linear(self.hidden_size, + config.q_lora_rank, + bias=config.attention_bias) + self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank) + self.q_b_proj = nn.Linear(config.q_lora_rank, + self.num_heads * self.q_head_dim, + bias=False) + + self.kv_a_proj_with_mqa = nn.Linear( + self.hidden_size, + config.kv_lora_rank + config.qk_rope_head_dim, + bias=config.attention_bias, + ) + self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank) + self.kv_b_proj = nn.Linear( + config.kv_lora_rank, + self.num_heads * + (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), + bias=False, + ) + + self.o_proj = nn.Linear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=config.attention_bias, + ) + self._init_rope() + + self.softmax_scale = self.q_head_dim**(-0.5) + if self.config.rope_scaling is not None: + mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = self.config.rope_scaling["factor"] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.softmax_scale = self.softmax_scale * mscale * mscale + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = DeepseekV3RotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = DeepseekV3LinearScalingRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = DeepseekV3DynamicNTKScalingRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "yarn": + kwargs = { + key: self.config.rope_scaling[key] + for key in [ + "original_max_position_embeddings", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + ] if key in self.config.rope_scaling + } + self.rotary_emb = DeepseekV3YarnRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + **kwargs, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return (tensor.view(bsz, seq_len, self.num_heads, + self.v_head_dim).transpose(1, 2).contiguous()) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.size() + + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + kv = (self.kv_b_proj(self.kv_a_layernorm(compressed_kv)).view( + bsz, q_len, self.num_heads, + self.qk_nope_head_dim + self.v_head_dim).transpose(1, 2)) + + k_nope, value_states = torch.split( + kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + kv_seq_len = value_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index.") + kv_seq_len += get_usable_length(past_key_value, kv_seq_len, + self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + query_states = k_pe.new_empty(bsz, self.num_heads, q_len, + self.q_head_dim) + query_states[:, :, :, :self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim:] = q_pe + + key_states = k_pe.new_empty(bsz, self.num_heads, q_len, + self.q_head_dim) + key_states[:, :, :, :self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim:] = k_pe + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs) + + attn_weights = ( + torch.matmul(query_states, key_states.transpose(2, 3)) * + self.softmax_scale) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}") + assert attention_mask is not None + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, + dim=-1, + dtype=torch.float32).to( + query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, + p=self.attention_dropout, + training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is" + f" {attn_output.size()}") + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, + self.num_heads * self.v_head_dim) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->DeepseekV3 +class DeepseekV3FlashAttention2(DeepseekV3Attention): + """ + DeepseekV3 flash attention module. This module inherits from `DeepseekV3Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10( + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + # DeepseekV3FlashAttention2 attention does not support output_attentions + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + kv = (self.kv_b_proj(self.kv_a_layernorm(compressed_kv)).view( + bsz, q_len, self.num_heads, + self.qk_nope_head_dim + self.v_head_dim).transpose(1, 2)) + + k_nope, value_states = torch.split( + kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + kv_seq_len = value_states.shape[-2] + + kv_seq_len = value_states.shape[-2] + if past_key_value is not None: + kv_seq_len += get_usable_length(past_key_value, kv_seq_len, + self.layer_idx) + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + query_states = k_pe.new_empty(bsz, self.num_heads, q_len, + self.q_head_dim) + query_states[:, :, :, :self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim:] = q_pe + + key_states = k_pe.new_empty(bsz, self.num_heads, q_len, + self.q_head_dim) + key_states[:, :, :, :self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim:] = k_pe + + if self.q_head_dim != self.v_head_dim: + value_states = F.pad(value_states, + [0, self.q_head_dim - self.v_head_dim]) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (DeepseekV3RMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + elif torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + else: + target_dtype = (self.q_proj.weight.dtype if self.q_lora_rank + is None else self.q_a_proj.weight.dtype) + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}.") + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + softmax_scale=self.softmax_scale, + ) + if self.q_head_dim != self.v_head_dim: + attn_output = attn_output[:, :, :, :self.v_head_dim] + + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * + self.v_head_dim).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in DeepseekV3FlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + ( + query_states, + key_states, + value_states, + indices_q, + cu_seq_lens, + max_seq_lens, + ) = self._upad_input(query_states, key_states, value_states, + attention_mask, query_length) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, + query_length) + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, + query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data( + attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, + head_dim), + indices_k, + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, + head_dim), + indices_k, + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, + head_dim), + indices_k, + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( + query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +ATTENTION_CLASSES = { + "eager": DeepseekV3Attention, + "flash_attention_2": DeepseekV3FlashAttention2, +} + + +class DeepseekV3DecoderLayer(nn.Module): + + def __init__(self, config: DeepseekV3Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = ATTENTION_CLASSES[config._attn_implementation]( + config=config, layer_idx=layer_idx) + + self.mlp = (DeepseekV3MoE(config) if + (config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0) else + DeepseekV3MLP(config)) + self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = DeepseekV3RMSNorm( + config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, + torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states, ) + + if output_attentions: + outputs += (self_attn_weights, ) + + if use_cache: + outputs += (present_key_value, ) + + return outputs + + +DeepseekV3_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`DeepseekV3Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.", + DeepseekV3_START_DOCSTRING, +) +class DeepseekV3PreTrainedModel(PreTrainedModel): + config_class = DeepseekV3Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["DeepseekV3DecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +DeepseekV3_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.", + DeepseekV3_START_DOCSTRING, +) +class DeepseekV3Model(DeepseekV3PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV3DecoderLayer`] + + Args: + config: DeepseekV3Config + """ + + def __init__(self, config: DeepseekV3Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, + self.padding_idx) + self.layers = nn.ModuleList([ + DeepseekV3DecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ]) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self.norm = DeepseekV3RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = (output_attentions if output_attentions is not None + else self.config.output_attentions) + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else + self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = (return_dict if return_dict is not None else + self.config.use_return_dict) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError( + "You have to specify either input_ids or inputs_embeds") + + past_key_values_length = 0 + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache( + past_key_values) + past_key_values_length = get_usable_length(past_key_values, + seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = (attention_mask if + (attention_mask is not None + and 0 in attention_mask) else None) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states, ) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[ + 2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1], ) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states, ) + + next_cache = None + if use_cache: + next_cache = (next_decoder_cache.to_legacy_cache() + if use_legacy_cache else next_decoder_cache) + if not return_dict: + return tuple( + v for v in + [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = DeepseekV3Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, + config.vocab_size, + bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, + config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, transformers., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM + + >>> model = DeepseekV3ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = (output_attentions if output_attentions is not None + else self.config.output_attentions) + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = (return_dict if return_dict is not None else + self.config.use_return_dict) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits, ) + outputs[1:] + return (loss, ) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + **kwargs, + ): + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + # seen_tokens 可能在某些 transformers 版本中不存在,使用 getattr 安全访问 + past_length = getattr(past_key_values, 'seen_tokens', + cache_length) + max_cache_length = past_key_values.get_max_length() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + if (attention_mask is not None + and attention_mask.shape[1] > input_ids.shape[1]): + input_ids = input_ids[:, -(attention_mask.shape[1] - + past_length):] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if (max_cache_length is not None and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1]:] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update({ + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + }) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += (tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past), ) + return reordered_past + + +@add_start_docstrings( + """ + The DeepseekV3 Model transformer with a sequence classification head on top (linear layer). + + [`DeepseekV3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + DeepseekV3_START_DOCSTRING, +) +class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel): + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = DeepseekV3Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, transformers., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = (return_dict if return_dict is not None else + self.config.use_return_dict) + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError( + "Cannot handle batch sizes > 1 if no padding token is defined." + ) + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.eq( + input_ids, self.config.pad_token_id).int().argmax(-1) - + 1).to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), + sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long + or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), + labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits, ) + transformer_outputs[1:] + return ((loss, ) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/modeling_kimi_k25.py b/modeling_kimi_k25.py new file mode 100644 index 0000000000000000000000000000000000000000..042fde845f33678e10b0723ca1ec6e65d59f8e17 --- /dev/null +++ b/modeling_kimi_k25.py @@ -0,0 +1,1248 @@ +# coding=utf-8 +# Copyright 2025-2026 The Moonshot AI Team, DeepSeek-AI, and HuggingFace Inc. team. All rights reserved. +# +# The code is based on llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py), but modified for Kimi-K2.5. +# +# Licensing Information: +# - Code derived from llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py) is licensed under the Apache License, Version 2.0. +# - Other parts of the code are licensed under the MIT License. +# +# Apache License, Version 2.0: +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# MIT License: +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +import math +from collections.abc import Sequence +from copy import deepcopy +from typing import Optional + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import activations + +try: + from transformers.activations import PytorchGELUTanh +except ImportError: + from transformers.activations import GELUTanh + activations.PytorchGELUTanh = GELUTanh + PytorchGELUTanh = GELUTanh +from transformers.activations import PytorchGELUTanh +from transformers.cache_utils import Cache +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_utils import PreTrainedModel +from transformers.models.llava.modeling_llava import \ + LlavaCausalLMOutputWithPast +from transformers.utils import is_flash_attn_2_available + +from .configuration_kimi_k25 import KimiK25Config +from .modeling_deepseek import DeepseekV3ForCausalLM + +# Flash attention imports +if is_flash_attn_2_available(): + from flash_attn import flash_attn_varlen_func +else: + flash_attn_varlen_func = None + + +def multihead_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + q_cu_seqlens: torch.Tensor | None = None, + k_cu_seqlens: torch.Tensor | None = None, + max_seqlen_q: int | None = None, + max_seqlen_k: int | None = None, + deterministic: bool = False, +): + """Multi-head attention using flash attention 2. + + Args: + q, k, v: tensor of shape (batch_size, seqlen, num_heads, head_dim), + or (tot_seqlens, num_heads, head_dim) if packing. + q_cu_seqlens (torch.Tensor): cumulative sequence lengths of q. + The first element should be 0 and the last element should be q.shape[0]. + k_cu_seqlens (torch.Tensor): cumulative sequence lengths of k. + The first element should be 0 and the last element should be k.shape[0]. + + Returns: + output: shape (batch_size, seqlen, dim) or (tot_seqlens, dim) if packing, + where dim = num_heads * head_dim + """ + attn_out = flash_attn_varlen_func( + q, + k, + v, + q_cu_seqlens, + k_cu_seqlens, + max_seqlen_q, + max_seqlen_k, + causal=False, + deterministic=deterministic, + ) + if isinstance(attn_out, tuple): + attn_out = attn_out[0] + + attn_out = attn_out.flatten(start_dim=-2) + + return attn_out + + +def eager_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + q_cu_seqlens: Optional[torch.Tensor] = None, + k_cu_seqlens: Optional[torch.Tensor] = None, + **kwargs, +) -> torch.Tensor: + seq_length = q.shape[0] + attention_mask = torch.zeros([1, seq_length, seq_length], + device=q.device, + dtype=torch.bool) + for i in range(1, len(q_cu_seqlens)): + attention_mask[ + ..., + q_cu_seqlens[i - 1]:q_cu_seqlens[i], + q_cu_seqlens[i - 1]:q_cu_seqlens[i], + ] = True + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + + attn_weight = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1]) + attn_weight += attention_mask + attn_weight = torch.softmax(attn_weight, dim=-1, + dtype=torch.float32).to(q.dtype) + + attn_output = attn_weight @ v + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + return attn_output + + +VL_VISION_ATTENTION_FUNCTIONS = { + "flash_attention_2": multihead_attention, + "eager": eager_attention, +} + + +def _apply_rope_input_validation(x, freqs_cis): + assert x.ndim == freqs_cis.ndim + 1, (x.shape, freqs_cis.shape) + assert x.shape[:-2] == freqs_cis.shape[:-1], (x.shape, freqs_cis.shape) + assert x.shape[-1] == 2 * freqs_cis.shape[-1], (x.shape, freqs_cis.shape) + assert freqs_cis.dtype == torch.complex64, freqs_cis.dtype + + +def get_rope_shape_decorate(func): + _get_rope_shape_first_call_flag = set() + + def wrapper(org, interpolation_mode, shape): + key = (org.requires_grad, torch.is_grad_enabled(), interpolation_mode) + if key not in _get_rope_shape_first_call_flag: + _get_rope_shape_first_call_flag.add(key) + _ = func(org, interpolation_mode, shape=(64, 64)) + return func(org, interpolation_mode, shape) + + return wrapper + + +@get_rope_shape_decorate +@torch.compile(dynamic=True) +def get_rope_shape(org, interpolation_mode, shape): + return (F.interpolate( + org.permute((2, 0, 1)).unsqueeze(0), + size=shape, + mode=interpolation_mode, + ).squeeze(0).permute((1, 2, 0)).flatten(end_dim=1)) + + +def apply_rope(xq: torch.Tensor, xk: torch.Tensor, + freqs_cis: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Args: (The leading dimensions of all inputs should be the same) + xq: query, tensor of shape (..., num_heads, head_dim) + xk: key, tensor of shape (..., num_heads, head_dim) + freqs_cis: tensor of shape (..., head_dim/2), dtype=torch.complex64. It contains the precomputed cis(freqs) for each position in the 2D grid. + Returns: + xq_out, xk_out: tensors of shape (..., num_heads, head_dim) + """ + _apply_rope_input_validation(xq, freqs_cis) + _apply_rope_input_validation(xk, freqs_cis) + + freqs_cis = freqs_cis.unsqueeze(-2) # ..., 1, head_dim/2 + # ..., num_heads, head_dim/2 + xq_ = torch.view_as_complex(xq.float().view(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().view(*xq.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten( + -2) # ..., num_heads, head_dim + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten( + -2) # ..., num_heads, head_dim + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + From: + https://github.com/OpenGVLab/InternVideo/blob/421f6d2361fc8f61a3394244571f2601a4e99e29/InternVideo2/multi_modality/models/backbones/internvideo2/pos_embed.py#L86 + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def get_1d_sincos_pos_embed(embed_dim, t_size, cls_token=False): + """ + t_size: int of the temporal size + return: + pos_embed: [t_size, embed_dim] or [1+t_size, embed_dim] (w/ or w/o cls_token) + """ + grid_t = np.arange(t_size, dtype=np.float32) + pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid_t) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], + axis=0) + return pos_embed + + +class Learnable2DInterpPosEmbDivided_fixed(nn.Module): + + def __init__(self, + height: int, + width: int, + num_frames: int, + dim: int, + interpolation_mode: str = 'bicubic') -> None: + super().__init__() + self.height = height + self.width = width + self.num_frames = num_frames + self.dim = dim + self.interpolation_mode = interpolation_mode + self.weight = nn.Parameter(torch.empty(height, width, dim)) + self.register_buffer('time_weight', + torch.from_numpy( + get_1d_sincos_pos_embed( + self.dim, + self.num_frames)).float().unsqueeze(1), + persistent=False) + + self.reset_parameters() + + def reset_parameters(self): + nn.init.normal_(self.weight) + + def forward(self, x: torch.Tensor, + grid_thws: torch.Tensor) -> torch.Tensor: + pos_embs = [] + for t, h, w in grid_thws.tolist(): + assert t <= self.num_frames, f't:{t} > self.num_frames:{self.num_frames}' + if (h, w) == self.weight.shape[:-1]: + pos_emb_2d = self.weight.flatten(end_dim=1) + else: + pos_emb_2d = get_rope_shape( + self.weight, + interpolation_mode=self.interpolation_mode, + shape=(h, w), + ) + + if t == 1: + pos_emb_3d = pos_emb_2d + else: + pos_emb_3d = pos_emb_2d.unsqueeze(0).repeat( + t, 1, 1) + self.time_weight[0:t] + + pos_embs.append(pos_emb_3d.reshape(-1, pos_emb_3d.shape[-1])) + + out = x + torch.cat(pos_embs) + return out + + +class MoonVision3dPatchEmbed(nn.Module): + + def __init__(self, + out_dim: int, + in_dim: int = 3, + patch_size: int | tuple[int, int] = (14, 14), + pos_emb_height: int = 14, + pos_emb_width: int = 14, + pos_emb_time: int = 4, + pos_emb_type: str = 'divided_fixed'): + super().__init__() + assert isinstance( + patch_size, + int | Sequence), f'Invalid patch_size type: {type(patch_size)}' + if isinstance(patch_size, int): + patch_size = (patch_size, patch_size) + assert (len(patch_size) == 2 + ), f'Expected patch_size to be a tuple of 2, got {patch_size}' + self.patch_size = patch_size + + self.proj = nn.Conv2d(in_dim, + out_dim, + kernel_size=patch_size, + stride=patch_size) + + if pos_emb_type == 'divided_fixed': + self.pos_emb = Learnable2DInterpPosEmbDivided_fixed( + height=pos_emb_height, + width=pos_emb_width, + num_frames=pos_emb_time, + dim=out_dim) + else: + raise NotImplementedError( + f'Not support pos_emb_type: {pos_emb_type}') + + def forward(self, x: torch.Tensor, + grid_thws: torch.Tensor) -> torch.Tensor: + """ + Args: + x (L, Channels): input tensor + grid_hws (N, 3): temporal, height and width + + Returns: + (L, Cout) tensor + """ + x = self.proj(x).view(x.size(0), -1) + # apply positional embedding + x = self.pos_emb(x, grid_thws) + return x + + +class Rope2DPosEmbRepeated(nn.Module): + """2D rotary position embedding with multi-resolution support. + + This class is intended to be used in the following way: + 1. Before training, create an instance of Rope2DPosEmb. This instance will hold the precomputed cis. + 2. Before each forward pass, call `get_freqs_cis_by_*` to get the `freqs_cis` tensor for this iteration. + 3. During the forward pass, pass the `freqs_cis` tensor to each attention layer, and call `apply` just before each attention operation. + The rope is shared across all attention layers and all heads. + + Refs: + - RoFormer: https://arxiv.org/abs/2104.09864 + - VisionLLaMA: https://arxiv.org/abs/2403.00522 + - https://github.com/Meituan-AutoML/VisionLLaMA/blob/main/dit/models.py + + Args: + dim (int): usually the multi-head attention dimension, should be divisible by 4 (TODO: relax this constraint if needed) + max_height (int): the maximum height of the 2D grid + max_width (int): the maximum width of the 2D grid + theta_base (float): the base of the theta + device (str): the device to store the precomputed cis + """ + + def __init__(self, + dim: int, + max_height: int, + max_width: int, + theta_base=10000): + super().__init__() + self.dim = dim + assert self.dim % 4 == 0, 'dim must be divisible by 4' + self.max_height = max_height + self.max_width = max_width + self.theta_base = theta_base + + def extra_repr(self): + return f'dim={self.dim}, max_height={self.max_height}, max_width={self.max_width}, theta_base={self.theta_base}' + + def _precompute_freqs_cis(self, device: torch.device) -> torch.Tensor: + """Calculate the cis(freqs) for each position in the 2D grid. + + Return: complex tensor of shape (max_height, max_width, dim//2) and value: + height axis: ret[h, w, 2*i] = cis(h * theta_base**(-4*i/dim)) + weight axis: ret[h, w, 2*i+1] = cis(w * theta_base**(-4*i/dim)) with (i in [0, dim//4)) + note: `cis` is a mathematical notation defined by cis x = cos x + i sin x, + """ + N = self.max_height * self.max_width + flat_pos = torch.arange(0, N).float().to(device) + x_pos = flat_pos % self.max_width + y_pos = flat_pos // self.max_width + dim_range = (torch.arange(0, self.dim, + 4)[:(self.dim // 4)].float().to(device) + ) # C/4 + freqs = 1.0 / (self.theta_base**(dim_range / self.dim)) + x_freqs = torch.outer(x_pos, freqs).float() # N, C/4 + y_freqs = torch.outer(y_pos, freqs).float() # N, C/4 + x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) # N, C/4 + y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) # N, C/4 + # N, C/4, 2 + freqs_cis = torch.cat( + [x_cis.unsqueeze(dim=-1), + y_cis.unsqueeze(dim=-1)], dim=-1) + # max_height, max_width, C/2 + freqs_cis = freqs_cis.reshape(self.max_height, self.max_width, -1) + return freqs_cis + + def get_freqs_cis(self, grid_thws: torch.Tensor, + device: torch.device) -> torch.Tensor: + """ + Args: + grid_thws (torch.Tensor): grid time, height and width + + Returns: + freqs_cis: tensor of shape (sum(t * height * width), dim//2) + """ + if not hasattr(self, 'freqs_cis'): + self.register_buffer('freqs_cis', + self._precompute_freqs_cis(device), + persistent=False) + + shapes = grid_thws.tolist() + assert all(1 <= h <= self.max_height and 1 <= w <= self.max_width + for t, h, w in shapes), ( + shapes, + self.max_height, + self.max_width, + ) + freqs_cis = torch.cat( + [ + self.freqs_cis[:h, :w].reshape(-1, self.dim // 2).repeat(t, 1) + for t, h, w in shapes + ], + dim=0, + ) + return freqs_cis + + +class MLP2(nn.Module): + """ + Args: + dims: [in_dim, hidden_dim, out_dim] + bias: whether to use bias in linear layer. + """ + + def __init__(self, dims: list[int], activation, bias=True): + super().__init__() + assert len(dims) == 3 + self.fc0 = nn.Linear(dims[0], dims[1], bias=bias) + self.fc1 = nn.Linear(dims[1], dims[2], bias=bias) + self.activation = activation + for m in [self.fc0, self.fc1]: + nn.init.trunc_normal_(m.weight, std=math.sqrt(2 / m.in_features)) + if m.bias is not None: + nn.init.zeros_(m.bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.fc0(x) + x = self.activation(x) + return self.fc1(x) + + +class MoonViTEncoderLayer(nn.Module): + + def __init__( + self, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + *, + attn_implementation: str = 'flash_attention_2', + activation=F.gelu, + attn_bias: bool = False, + use_deterministic_attn: bool = False, + ): + super().__init__() + self.num_heads = num_heads + self.hidden_dim = hidden_dim + self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads + self.attn_implementation = attn_implementation + self.use_deterministic_attn = use_deterministic_attn + + self.norm0 = nn.LayerNorm(hidden_dim) + self.norm1 = nn.LayerNorm(hidden_dim) + self.mlp = MLP2([hidden_dim, mlp_dim, hidden_dim], activation) + self.wqkv = nn.Linear(hidden_dim, hidden_dim * 3, bias=attn_bias) + self.wo = nn.Linear(hidden_dim, hidden_dim, bias=attn_bias) + + def attention_qkvpacked( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: torch.Tensor, + rope_freqs_cis: torch.Tensor | None = None, + ): + """ + Args: + x (torch.Tensor): (batch_size, seqlen, hidden_dim) + cu_seqlens (torch.Tensor): + """ + xqkv = self.wqkv(x) + + qkv_shape = xqkv.size()[:-1] + ( + 3, + self.num_heads, + self.hidden_size_per_attention_head, + ) + # xqkv: (batch_size, seqlen, 3, nheads, headdim) + xqkv = xqkv.view(*qkv_shape) + xq, xk, xv = torch.unbind(xqkv, dim=-3) + + xq, xk = apply_rope(xq, xk, rope_freqs_cis) + + attn_func = VL_VISION_ATTENTION_FUNCTIONS[self.attn_implementation] + attn_out = attn_func(xq, + xk, + xv, + q_cu_seqlens=cu_seqlens, + k_cu_seqlens=cu_seqlens, + max_seqlen_k=max_seqlen, + max_seqlen_q=max_seqlen, + deterministic=self.use_deterministic_attn) + + attn_out = self.wo(attn_out) + return attn_out + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + rope_freqs_cis: torch.Tensor | None = None, + ): + residual = hidden_states + hidden_states = self.norm0(hidden_states) + + hidden_states = self.attention_qkvpacked(hidden_states, cu_seqlens, + max_seqlen, rope_freqs_cis) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class MoonViT3dEncoder(nn.Module): + + def __init__(self, + hidden_dim: int, + num_layers: int, + block_cfg: dict, + video_attn_type: str = 'spatial_temporal') -> None: + super().__init__() + + assert video_attn_type == 'spatial_temporal', f'video_attn_type must be "spatial_temporal", got {video_attn_type}' + self.video_attn_type = video_attn_type + self.rope_2d = Rope2DPosEmbRepeated( + block_cfg['hidden_dim'] // block_cfg['num_heads'], 512, 512) + self.blocks = nn.ModuleList([ + MoonViTEncoderLayer( + **block_cfg, + use_deterministic_attn=self.use_deterministic_attn) + for _ in range(num_layers) + ]) + self.final_layernorm = nn.LayerNorm(hidden_dim) + + def forward( + self, + hidden_states: torch.Tensor, + grid_thws: torch.Tensor, + ) -> torch.Tensor: + rope_freqs_cis = self.rope_2d.get_freqs_cis( + grid_thws=grid_thws, device=hidden_states.device) + + lengths = torch.cat(( + torch.zeros(1, dtype=grid_thws.dtype, device=grid_thws.device), + grid_thws[:, 0] * grid_thws[:, 1] * grid_thws[:, 2], + )) + + max_seqlen = lengths.max() + cu_seqlens = lengths.to(hidden_states.device).cumsum(dim=0, + dtype=torch.int32) + for block in self.blocks: + hidden_states = block(hidden_states, + cu_seqlens, + max_seqlen, + rope_freqs_cis=rope_freqs_cis) + + hidden_states = self.final_layernorm(hidden_states) + return hidden_states + + +def tpool_patch_merger( + x: torch.Tensor, + grid_thws: torch.Tensor, + merge_kernel_size: tuple[int, int] = (2, 2), +) -> list[torch.Tensor]: + d_model = x.size(-1) + + outputs = [] + pre_sum = 0 + for t, h, w in grid_thws.tolist(): + # Get the current sequence + seq = x[pre_sum:pre_sum + t * h * w] + # Reshape along self.merge_kernel_size and concat to the last dimension + kernel_height, kernel_width = merge_kernel_size + new_height, new_width = h // kernel_height, w // kernel_width + reshaped_seq = seq.view(t, new_height, kernel_height, new_width, + kernel_width, d_model) + reshaped_seq = reshaped_seq.permute(0, 1, + 3, 2, 4, 5).contiguous().mean( + dim=0) # temporal pooling + padded_seq = reshaped_seq.view(new_height * new_width, + kernel_height * kernel_width, -1) + outputs.append(padded_seq) + pre_sum += t * h * w + + return outputs + + +class MoonViT3dPretrainedModel(PreTrainedModel): + config_class = None + model_type = 'moonvit3d' + _no_split_modules = ['PackingTransformer'] + _supports_flash_attn_2 = True + _supports_sdpa = True + + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + config = deepcopy(config) + self.merge_kernel_size = config.merge_kernel_size + self.patch_size = config.patch_size + self.merge_type = config.merge_type + + self.patch_embed = MoonVision3dPatchEmbed( + out_dim=config.hidden_size, + patch_size=config.patch_size, + pos_emb_height=config.init_pos_emb_height, + pos_emb_width=config.init_pos_emb_width, + pos_emb_time=config.init_pos_emb_time, + pos_emb_type=config.pos_emb_type, + ) + + self.encoder = MoonViT3dEncoder(hidden_dim=config.hidden_size, + num_layers=config.num_hidden_layers, + block_cfg={ + 'num_heads': + config.num_attention_heads, + 'hidden_dim': + config.hidden_size, + 'mlp_dim': + config.intermediate_size, + 'activation': + PytorchGELUTanh(), + 'attn_bias': + True, + 'attn_implementation': + config._attn_implementation, + }, + video_attn_type=config.video_attn_type) + + def forward(self, pixel_values: torch.Tensor, + grid_thws: torch.Tensor) -> torch.Tensor: + """ + Args: + pixel_values (torch.Tensor): The input pixel values. + grid_thws (torch.Tensor): Temporal, height and width. + + Returns: + torch.Tensor: The output tokens. + """ + # grid_thws = grid_thws.to('cpu') + assert grid_thws.ndim == 2, f'grid_thws should be 2D, got {grid_thws.ndim}' + assert grid_thws.size(1) == 3, f'No support for thw: {grid_thws}' + hidden_states = self.patch_embed(pixel_values, grid_thws) + hidden_states = self.encoder(hidden_states, grid_thws) + if self.merge_type == 'sd2_tpool': # spatial downsampling 2x with temporal pooling all + hidden_states = tpool_patch_merger( + hidden_states, + grid_thws, + merge_kernel_size=self.merge_kernel_size) + else: + raise NotImplementedError(f'Not support {self.merge_type}') + + return hidden_states + + +# ============================================================================ +# MM Projector Helper Classes (from mm_projector/modeling_mm_projectors.py) +# ============================================================================ + + +class IdentityMap(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, x, *args, **kwargs): + return x + + +class MLP(nn.Module): + + def __init__(self, config): + super().__init__() + # TODO, use faster LayerNorm + self.pre_norm = nn.LayerNorm(config.mm_hidden_size) + self.proj = nn.Sequential( + nn.Linear(config.mm_hidden_size, config.hidden_size), nn.GELU(), + nn.Linear(config.hidden_size, config.hidden_size)) + + def forward(self, x, *args, **kwargs): + assert isinstance(x, + list | tuple), f'x is not a list or tuple: {type(x)}' + lengths = [item.shape[0] for item in x] + x = torch.cat(x, dim=0) + x = self.pre_norm(x) + x = self.proj(x) + x = torch.split(x, lengths, dim=0) + + return x + + +class PatchMergerMLP(nn.Module): + + def __init__(self, config): + super().__init__() + eps = config.projector_ln_eps + self.hidden_size = config.mm_hidden_size * ( + config.merge_kernel_size[0] * config.merge_kernel_size[1]) + self.pre_norm = nn.LayerNorm(config.mm_hidden_size, eps=eps) + self.proj = nn.Sequential( + nn.Linear(self.hidden_size, self.hidden_size), + nn.GELU(), + nn.Linear(self.hidden_size, config.hidden_size), + ) + + def forward(self, x, *args, **kwargs): + if isinstance(x, list) or isinstance(x, tuple): + x = [ + self.proj(self.pre_norm(item).view(item.shape[0], -1)) + for item in x + ] + else: + # B, N, N_k, C = x.shape + B = x.shape[0] + x = self.proj(self.pre_norm(x).view(B, -1, self.hidden_size)) + return x + + +class KimiK25PreTrainedModel(PreTrainedModel): + config_class = KimiK25Config + base_model_prefix = "model" + _no_split_modules = [ + "MoonViT3dPretrainedModel", + "MoonViTEncoderLayer", + "DeepseekDecoderLayer", + "PatchMergerMLP", + ] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = False + + def _init_weights(self, module): + # important: this ported version of Llava isn't meant for training from scratch - only + # inference and fine-tuning - so the proper init weights code has been removed - the original codebase + # https://github.com/haotian-liu/LLaVA/tree/main/llava should serve for that purpose + std = (self.config.initializer_range if hasattr( + self.config, "initializer_range") else + self.config.text_config.initializer_range) + + if hasattr(module, "class_embedding"): + module.class_embedding.data.normal_(mean=0.0, std=std) + + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class VisionTowerConfig(PretrainedConfig): + model_type = 'moonvit3d' + + def __init__(self, config: KimiK25Config, **kwargs): + super().__init__(**kwargs) + self.patch_size = config.patch_size + self.init_pos_emb_height = config.init_pos_emb_height + self.init_pos_emb_width = config.init_pos_emb_width + self.init_pos_emb_time = config.init_pos_emb_time + self.pos_emb_type = config.pos_emb_type + self.num_attention_heads = config.vt_num_attention_heads + self.num_hidden_layers = config.vt_num_hidden_layers + self.hidden_size = config.vt_hidden_size + self.intermediate_size = config.vt_intermediate_size + self.merge_kernel_size = config.merge_kernel_size + self.video_attn_type = config.video_attn_type + self.merge_type = config.merge_type + self._attn_implementation = config._attn_implementation + + +class ProjectorConfig: + + def __init__(self, config: KimiK25Config): + self.mm_projector_type = config.mm_projector_type + self.mm_hidden_size = config.mm_hidden_size + self.hidden_size = config.text_hidden_size + self.merge_kernel_size = config.merge_kernel_size + self.projector_hidden_act = config.projector_hidden_act + self.projector_ln_eps = config.projector_ln_eps + + +# ref https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/llava/modeling_llava.py#L240 +class KimiK25ForConditionalGeneration(KimiK25PreTrainedModel): + + def __init__(self, config: KimiK25Config): + super().__init__(config) + + vt_config = VisionTowerConfig(config.vision_config) + self.vision_tower = MoonViT3dPretrainedModel(vt_config) + + proj_config = ProjectorConfig(config.vision_config) + if proj_config.mm_projector_type == 'identity': + self.mm_projector = IdentityMap() + elif proj_config.mm_projector_type == 'mlp': + self.mm_projector = MLP(proj_config) + elif proj_config.mm_projector_type == 'patchmerger': + self.mm_projector = PatchMergerMLP(proj_config) + else: + raise ValueError( + f"Unsupported mm_projector_type: {proj_config.mm_projector_type}" + ) + + self.language_model = DeepseekV3ForCausalLM(config.text_config) + self.post_init() + + if hasattr(self.language_model, 'dtype'): + target_dtype = self.language_model.dtype + self.vision_tower = self.vision_tower.to(dtype=target_dtype) + self.mm_projector = self.mm_projector.to(dtype=target_dtype) + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.language_model.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.language_model.set_output_embeddings(new_embeddings) + + def set_decoder(self, decoder): + self.language_model.set_decoder(decoder) + + def get_decoder(self): + return self.language_model.get_decoder() + + def tie_weights(self): + return self.language_model.tie_weights() + + def resize_token_embeddings(self, + new_num_tokens: int | None = None, + pad_to_multiple_of=None) -> nn.Embedding: + model_embeds = self.language_model.resize_token_embeddings( + new_num_tokens, pad_to_multiple_of) + # update vocab size + self.config.text_config.vocab_size = model_embeds.num_embeddings + self.vocab_size = model_embeds.num_embeddings + return model_embeds + + def _merge_input_ids_with_image_features( + self, + image_features: list[torch.Tensor], + inputs_embeds: torch.Tensor, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + labels: torch.Tensor | None = None, + ): + """ + Args: + image_features (:obj:`torch.Tensor` of shape :obj:`(num_image_tokens, embed_dim)`): + The image features to merge with the input embeddings. + inputs_embeds (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length, embed_dim)`): + The input embeddings. + input_ids (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`): + The input ids. + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`): + The attention mask. + labels (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, *optional*): + The labels. + """ + _, embed_dim = image_features[0].shape + feature_lengths = [x.shape[0] for x in image_features] + image_features = torch.cat(image_features, dim=0) + + image_token_index: int = self.config.media_placeholder_token_id + pad_token_id: int = self.config.pad_token_id + ignore_index: int = self.config.ignore_index + + batch_size, sequence_length = input_ids.shape + left_padding = not torch.sum( + input_ids[:, -1] == torch.tensor(pad_token_id)) + + # 1. Create a mask to know where special image tokens are + _token_occupation_table = torch.ones_like(input_ids.flatten()) + _token_occupation_table[input_ids.flatten() == + image_token_index] = torch.tensor( + feature_lengths, + dtype=torch.long, + device=input_ids.device) + _token_occupation_table = _token_occupation_table.reshape( + input_ids.shape) + + max_embed_dim = _token_occupation_table.sum(-1).max().item() + assert ( + max_embed_dim >= sequence_length + ), f"The maximum embedding dimension ({max_embed_dim}) is less than the sequence length ({sequence_length})" + batch_indices, non_image_indices = torch.where( + input_ids != image_token_index) + + # 2. Compute the positions where text should be written + # Calculate new positions for text tokens in merged image-text sequence. + new_token_positions = torch.cumsum(_token_occupation_table, -1) - 1 + nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1] + if left_padding: + new_token_positions += nb_image_pad[:, + None] # offset for left padding + text_to_overwrite = new_token_positions[batch_indices, + non_image_indices] + + # 3. Create the full embedding, already padded to the maximum position + final_embedding = torch.zeros( + batch_size, + max_embed_dim, + embed_dim, + dtype=inputs_embeds.dtype, + device=inputs_embeds.device, + ) + final_attention_mask = torch.zeros(batch_size, + max_embed_dim, + dtype=attention_mask.dtype, + device=inputs_embeds.device) + if labels is not None: + final_labels = torch.full( + (batch_size, max_embed_dim), + ignore_index, + dtype=input_ids.dtype, + device=input_ids.device, + ) + # In case the Vision model or the Language model has been offloaded to CPU, we need to manually + # set the corresponding tensors into their correct target device. + target_device = inputs_embeds.device + batch_indices, non_image_indices, text_to_overwrite = ( + batch_indices.to(target_device), + non_image_indices.to(target_device), + text_to_overwrite.to(target_device), + ) + attention_mask = attention_mask.to(target_device) + + # 4. Fill the embeddings based on the mask. + final_embedding[batch_indices, + text_to_overwrite] = inputs_embeds[batch_indices, + non_image_indices] + final_attention_mask[batch_indices, + text_to_overwrite] = attention_mask[ + batch_indices, non_image_indices] + if labels is not None: + final_labels[batch_indices, + text_to_overwrite] = labels[batch_indices, + non_image_indices] + + # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835) + image_to_overwrite = torch.full((batch_size, max_embed_dim), + True, + dtype=torch.bool, + device=inputs_embeds.device) + image_to_overwrite[batch_indices, text_to_overwrite] = False + image_to_overwrite &= image_to_overwrite.cumsum( + -1) - 1 >= nb_image_pad[:, None].to(target_device) + + if image_to_overwrite.sum() != image_features.shape[:-1].numel(): + raise ValueError( + f"The input provided to the model are wrong. The number of image tokens is {image_to_overwrite.sum()} while" + f" the number of image features given to the model is {image_features.shape[:-1].numel()}. " + "This prevents correct indexing and breaks batch generation.") + + final_embedding[image_to_overwrite] = ( + image_features.contiguous().reshape(-1, + embed_dim).to(target_device)) + final_attention_mask |= image_to_overwrite + position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_( + (final_attention_mask == 0), 1) + + # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens. + batch_indices, pad_indices = torch.where(input_ids == pad_token_id) + indices_to_mask = new_token_positions[batch_indices, pad_indices] + + final_embedding[batch_indices, indices_to_mask] = 0 + + if labels is None: + final_labels = None + + return final_embedding, final_attention_mask, final_labels, position_ids + + def _extract_image_features(self, pixel_values: torch.Tensor, + grid_thws: torch.Tensor) -> list[torch.Tensor]: + """ + Args: + pixel_values (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_channels, height, width)`): + The pixel values of the images processed by image processor. + grid_thws (:obj:`torch.Tensor` of shape :obj:`(batch_size, 3)`): + The grid, height, width of the images. + + Returns: + selected_image_feature (:obj:`torch.FloatTensor` of shape :obj:`(num_image_tokens, embed_dim)`): + The selected image features to use as input to the projector head. + + """ + + target_dtype = self.vision_tower.patch_embed.proj.weight.dtype + pixel_values = pixel_values.to(target_dtype) + + image_features = self.vision_tower(pixel_values, grid_thws) + return image_features + + def forward( + self, + input_ids: torch.LongTensor | None = None, + pixel_values: torch.FloatTensor | list[torch.FloatTensor] + | None = None, + grid_thws: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + ) -> tuple | LlavaCausalLMOutputWithPast: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + ```""" + assert self.vision_tower is not None, "vision_tower is not loaded" + output_attentions = (output_attentions if output_attentions is not None + else self.config.output_attentions) + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is None: + # 1. Extra the input embeddings + inputs_embeds = self.get_input_embeddings()(input_ids) + + # 2. Merge text and images + if pixel_values is not None and len( + pixel_values) > 0 and input_ids.shape[1] != 1: + image_features = self._extract_image_features( + pixel_values, grid_thws) + if self.mm_projector: + image_features = self.mm_projector(image_features) + + inputs_embeds = inputs_embeds.to( + image_features[0].dtype) # num_tokens, embed_dim + inputs_embeds, attention_mask, labels, position_ids = ( + self._merge_input_ids_with_image_features( + image_features, + inputs_embeds, + input_ids, + attention_mask, + labels, + )) + + # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of + # generation with cache + elif (past_key_values is not None and pixel_values is not None + and input_ids.shape[1] == 1): + # Retrieve the first layer to inspect the logits and mask out the hidden states + # that are set to 0 + first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] + + # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 + batch_index, non_attended_tokens = torch.where( + first_layer_past_key_value.float().sum(-2) == 0) + + # Get the target length + target_length = input_ids.shape[1] + past_length = first_layer_past_key_value.shape[-1] + + extended_attention_mask = torch.ones( + (attention_mask.shape[0], past_length), + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + + # Filter out only the tokens that can be un-attended, this can happen + # if one uses Llava + Fused modules where the cache on the + # first iteration is already big enough, or if one passes custom cache + valid_indices = non_attended_tokens < extended_attention_mask.size( + -1) + new_batch_index = batch_index[valid_indices] + new_non_attended_tokens = non_attended_tokens[valid_indices] + + # Zero-out the places where we don't need to attend + extended_attention_mask[new_batch_index, + new_non_attended_tokens] = 0 + + attention_mask = torch.cat( + (extended_attention_mask, attention_mask[:, + -target_length:]), + dim=1) + position_ids = torch.sum(attention_mask, + dim=1).unsqueeze(-1) - 1 + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = outputs[0] + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + if attention_mask is not None: + shift_attention_mask = attention_mask[..., 1:] + shift_logits = logits[..., :-1, :][shift_attention_mask.to( + logits.device) != 0].contiguous() + shift_labels = labels[..., 1:][shift_attention_mask.to( + labels.device) != 0].contiguous() + else: + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), + shift_labels.view(-1).to(shift_logits.device), + ) + + if not return_dict: + output = (logits, ) + outputs[1:] + return (loss, ) + output if loss is not None else output + + return LlavaCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + pixel_values=None, + grid_thws=None, + attention_mask=None, + **kwargs, + ): + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = getattr(past_key_values, 'seen_tokens', + cache_length) + else: + cache_length = past_length = past_key_values[0][0].shape[2] + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[ + 1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - + past_length):] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + elif self.config.media_placeholder_token_id in input_ids: + input_ids = input_ids[:, input_ids.shape[1] - 1:] + # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the + # older attention values, as their corresponding values are not part of the input. + if cache_length < past_length and attention_mask is not None: + attention_mask = attention_mask[:, -(cache_length + + input_ids.shape[1]):] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1]:] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update({ + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "grid_thws": grid_thws, + }) + return model_inputs + + def _reorder_cache(self, *args, **kwargs): + return self.language_model._reorder_cache(*args, **kwargs) diff --git a/preprocessor_config.json b/preprocessor_config.json new file mode 100644 index 0000000000000000000000000000000000000000..ede361dcd63890f38249f5a8578124d287d6dc92 --- /dev/null +++ b/preprocessor_config.json @@ -0,0 +1,32 @@ +{ + "auto_map": { + "AutoImageProcessor": "kimi_k25_vision_processing.KimiK25VisionProcessor", + "AutoProcessor": "kimi_k25_processor.KimiK25Processor" + }, + "image_processor_type": "KimiK25VisionProcessor", + "media_proc_cfg": { + "config_type": "media_proc.processors.moonvit.MoonViTMediaProcessorConfig", + "fixed_output_tokens": null, + "image_mean": [ + 0.5, + 0.5, + 0.5 + ], + "image_std": [ + 0.5, + 0.5, + 0.5 + ], + "in_patch_limit": 16384, + "in_patch_limit_each_frame": 4096, + "in_patch_limit_video": null, + "max_num_frames_each_video": null, + "merge_kernel_size": 2, + "patch_limit_on_one_side": 512, + "patch_size": 14, + "sample_fps": 2.0, + "temporal_merge_kernel_size": 4, + "timestamp_mode": "hh:mm:ss.fff" + }, + "num_frames_per_chunk": 4 +} \ No newline at end of file diff --git a/processor_config.json b/processor_config.json new file mode 100644 index 0000000000000000000000000000000000000000..31838a3332c7aad5652f4fa910b2f5c3cae3636c --- /dev/null +++ b/processor_config.json @@ -0,0 +1,6 @@ +{ + "auto_map": { + "AutoProcessor": "kimi_k25_processor.KimiK25Processor" + }, + "processor_class": "KimiK25Processor" +} diff --git a/special_tokens_map.json b/special_tokens_map.json new file mode 100644 index 0000000000000000000000000000000000000000..75044cb3b83b5712ec5b8db5c81549097845339d --- /dev/null +++ b/special_tokens_map.json @@ -0,0 +1,44 @@ +{ + "additional_special_tokens": [ + "<|im_end|>", + "<|im_user|>", + "<|im_assistant|>", + "<|start_header_id|>", + "<|end_header_id|>", + "[EOT]", + "<|im_system|>", + "<|im_middle|>", + "<|media_begin|>", + "<|media_content|>", + "<|media_end|>", + "<|media_pad|>" + ], + "bos_token": { + "content": "[BOS]", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "eos_token": { + "content": "[EOS]", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "pad_token": { + "content": "[PAD]", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "unk_token": { + "content": "[UNK]", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + } +} diff --git a/tiktoken.model b/tiktoken.model new file mode 100644 index 0000000000000000000000000000000000000000..b4149a6e17a01b6442187f39890f89bc2fe8d309 --- /dev/null +++ b/tiktoken.model @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b6c497a7469b33ced9c38afb1ad6e47f03f5e5dc05f15930799210ec050c5103 +size 2795286 diff --git a/tokenization_kimi.py b/tokenization_kimi.py new file mode 100644 index 0000000000000000000000000000000000000000..adfb06553c48f258546246e6ef4e1d85ad9ef711 --- /dev/null +++ b/tokenization_kimi.py @@ -0,0 +1,351 @@ +import os +from collections import OrderedDict +from logging import getLogger +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast + +import tiktoken +from tiktoken.load import load_tiktoken_bpe +from tokenizers import AddedToken +from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode +from transformers.tokenization_utils import PreTrainedTokenizer + +from .tool_declaration_ts import encode_tools_to_typescript_style + +logger = getLogger(__name__) +VOCAB_FILES_NAMES = {"vocab_file": "tiktoken.model"} + + +class TikTokenTokenizer(PreTrainedTokenizer): + """ + Tokenizing and encoding/decoding text using the Tiktoken tokenizer. See megatron/tokenizer/tiktoken_tokenizer.py. + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + The path to the Tiktoken model file. + bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<|begin_of_text|>",`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<|end_of_text|>"`): + The end of sequence token. + unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<|reserved_special_token_249|>"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. The second to last item in special_tokens. + pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<|reserved_special_token_250|>"`): + The token used for padding, for example when batching sequences of different lengths. + additional_special_tokens (list of `str`, *optional*): + A tuple or a list of additional tokens, which will be marked as `special`, meaning that they will be + skipped when decoding if `skip_special_tokens` is set to `True`. + """ + + vocab_files_names = VOCAB_FILES_NAMES + + model_input_names = ["input_ids", "attention_mask"] + + special_tokens: Dict[str, int] + + num_reserved_special_tokens = 256 + + pat_str = "|".join([ + r"""[\p{Han}]+""", + r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?""", + r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?""", + r"""\p{N}{1,3}""", + r""" ?[^\s\p{L}\p{N}]+[\r\n]*""", + r"""\s*[\r\n]+""", + r"""\s+(?!\S)""", + r"""\s+""", + ]) + + def __init__( + self, + vocab_file, + bos_token: Union[str, AddedToken] = "[BOS]", + eos_token: Union[str, AddedToken] = "[EOS]", + unk_token: Union[str, AddedToken, None] = None, + pad_token: Union[str, AddedToken, None] = None, + additional_special_tokens: List[str] = None, + added_tokens_decoder: Optional[dict] = None, + **kwargs, + ): + assert os.path.isfile(vocab_file), vocab_file + + if additional_special_tokens is None: + additional_special_tokens = [ + "<|im_end|>", + "<|im_user|>", + "<|im_assistant|>", + "<|start_header_id|>", + "<|end_header_id|>", + "[EOT]", + "<|im_system|>", + "<|im_middle|>", + ] + + if added_tokens_decoder: + special_tokens_mapping = { + i: added_tokens_decoder[i].content + for i in added_tokens_decoder + } + else: + special_tokens_mapping = {} + + self.vocab_file = vocab_file + mergeable_ranks = load_tiktoken_bpe(vocab_file) + num_base_tokens = len(mergeable_ranks) + self.special_tokens = { + special_tokens_mapping.get(i, f"<|reserved_token_{i}|>"): i + for i in range(num_base_tokens, num_base_tokens + + self.num_reserved_special_tokens) + } + + self.model = tiktoken.Encoding( + name=Path(vocab_file).name, + pat_str=self.pat_str, + mergeable_ranks=mergeable_ranks, + special_tokens=self.special_tokens, + ) + logger.info(f"Reloaded tiktoken model from {vocab_file}") + + self.n_words: int = self.model.n_vocab + # BOS / EOS token IDs + self.bos_id: int = self.special_tokens[str(bos_token)] + self.eos_id: int = self.special_tokens[str(eos_token)] + logger.info( + f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}" + ) + + self.pad_id: int = self.special_tokens[str(pad_token)] + self.unk_id: int = self.special_tokens[str(unk_token)] + + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + + self.decoder = {} + for i in range(self.n_words): + # Taken from https://gist.github.com/xenova/a452a6474428de0182b17605a98631ee + decoding = ''.join([ + self.byte_encoder[ord(char)] for char in + self.model.decode_single_token_bytes(i).decode('latin-1') + ]) + self.decoder[i] = decoding + + self.encoder = {} + for i in range(self.n_words): + if i in self.decoder: + self.encoder[self.decoder[i]] = i + + self._token_config_cache = OrderedDict() + self._cache_max_size = 128 + + super().__init__( + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + additional_special_tokens=additional_special_tokens, + added_tokens_decoder=added_tokens_decoder, + **kwargs, + ) + self.all_special_ids_set = set(self.all_special_ids) + + def encode(self, + text: str, + allow_special_tokens: bool = True, + **kwargs) -> List[int]: + """ + Encodes a string into a list of token IDs. + + Args: + text (str): The input string to be encoded. + + Returns: + list[int]: A list of token IDs. + """ + # If there are other args, we should call super().encode because there are a lot of code + # to handle those args. supper().encode finally will call _tokenize and _convert_token_to_id. + # NOTE: our encode method is not compatible with the super().encode method, + # e.g. split_special_tokens' default is True in our encode method. + if len(kwargs) > 0: + logger.warning(f"Calling super().encode with {kwargs}") + return super().encode(text, **kwargs) + + assert type(text) is str + + # The tiktoken tokenizer can handle <=400k chars without + # pyo3_runtime.PanicException. + TIKTOKEN_MAX_ENCODE_CHARS = 400_000 + + # https://github.com/openai/tiktoken/issues/195 + # Here we iterate over subsequences and split if we exceed the limit + # of max consecutive non-whitespace or whitespace characters. + MAX_NO_WHITESPACES_CHARS = 25_000 + + texts = self.pre_tokenizer_process(text) + + all_substrs = [] + for text in texts: + substrs = ( + substr for i in range(0, len(text), TIKTOKEN_MAX_ENCODE_CHARS) + for substr in self._split_whitespaces_or_nonwhitespaces( + text[i:i + + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS)) + all_substrs.extend(substrs) + + t: List[int] = [] + for substr in all_substrs: + if allow_special_tokens: + t.extend( + # we should consider special token as a common token + self.model.encode( + substr, + allowed_special="all", + )) + else: + t.extend( + # we should consider special token as a common token + self.model.encode( + substr, + disallowed_special=(), + )) + + return t + + def decode(self, token_ids: Union[int, List[int]], **kwargs) -> str: + """ + Decodes a list of token IDs into a string. + + Args: + token_ids (List[int]): The list of token IDs to be decoded. + + Returns: + str: The decoded string. + """ + # If there are other args, we should call super().decode because there are a lot of code + # to handle those args. supper().encode finally will call convert_tokens_to_string and _convert_id_to_token. + if len(kwargs) > 0: + return super().decode(token_ids, **kwargs) + + if type(token_ids) is int: + token_ids = [token_ids] + + return self.model.decode(cast(List[int], token_ids)) + + @staticmethod + def _split_whitespaces_or_nonwhitespaces( + s: str, max_consecutive_slice_len: int) -> Iterator[str]: + """ + Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len` + consecutive whitespaces or consecutive non-whitespaces. + """ + current_slice_len = 0 + current_slice_is_space = s[0].isspace() if len(s) > 0 else False + slice_start = 0 + + for i in range(len(s)): + is_now_space = s[i].isspace() + + if current_slice_is_space ^ is_now_space: + current_slice_len = 1 + current_slice_is_space = is_now_space + else: + current_slice_len += 1 + if current_slice_len > max_consecutive_slice_len: + yield s[slice_start:i] + slice_start = i + current_slice_len = 1 + yield s[slice_start:] + + def pre_tokenizer_process(self, text: str) -> List[str]: + """ + pre-tokenizes the input text into a list of tokens. + This method is used to split the input text into smaller chunks for internal processing. + """ + return [text] + + """ ----- Below are the abstract methods required by PreTrainedTokenizer ----- """ + + @property + def vocab_size(self) -> int: + return self.n_words + + def get_vocab(self) -> Dict[str, int]: + return self.encoder + + def _tokenize(self, text: str, **kwargs) -> List[str]: + return [self.decoder[t] for t in self.encode(text)] + + def _convert_token_to_id(self, token: str) -> int: + return self.encoder.get(token, self.unk_id) + + def _convert_id_to_token(self, index: int) -> str: + return self.decoder.get(index) + + @staticmethod + def clean_up_tokenization(out_string: str) -> str: + return out_string + + def convert_tokens_to_string(self, tokens: List[str]) -> str: + text = ''.join(tokens) + text = bytearray([self.byte_decoder[c] + for c in text]).decode('utf-8', 'replace') + return text + + def save_vocabulary(self, + save_directory: str, + filename_prefix: Optional[str] = None) -> Tuple[str]: + if not os.path.isdir(save_directory): + raise ValueError( + f"vocabulary path ({save_directory}) should be a directory") + out_vocab_file = os.path.join( + save_directory, + (filename_prefix + "-" if filename_prefix else "") + + VOCAB_FILES_NAMES["vocab_file"]) + + if os.path.abspath(self.vocab_file) != os.path.abspath( + out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + + return (out_vocab_file, ) + + def apply_chat_template(self, + conversation, + tools: Optional[list[dict]] = None, + tokenize: bool = False, + add_generation_prompt: bool = True, + thinking: bool = True, + **kwargs): + + tools = deep_sort_dict(tools) + + # Convert tools to TypeScript style string if tools are provided + tools_ts_str = None + if tools: + try: + tools_ts_str = encode_tools_to_typescript_style(tools) + + except Exception as e: + print(f"Failed to convert tools to TypeScript style: {e}") + tools_ts_str = None + + # Store the TypeScript string in kwargs so it can be accessed by the template + if tools_ts_str is not None: + kwargs['tools_ts_str'] = tools_ts_str + return super().apply_chat_template( + conversation, + tools=tools, + tokenize=tokenize, + add_generation_prompt=add_generation_prompt, + thinking=thinking, + **kwargs) + + +def deep_sort_dict(obj: Any) -> Any: + if isinstance(obj, dict): + return {k: deep_sort_dict(v) for k, v in sorted(obj.items())} + if isinstance(obj, list): + return [deep_sort_dict(item) for item in obj] + return obj diff --git a/tokenizer_config.json b/tokenizer_config.json new file mode 100644 index 0000000000000000000000000000000000000000..15ebc2bf406ed3534084931a4ade35a86e1970f1 --- /dev/null +++ b/tokenizer_config.json @@ -0,0 +1,220 @@ +{ + "added_tokens_decoder": { + "163584": { + "content": "[BOS]", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "163585": { + "content": "[EOS]", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "163586": { + "content": "<|im_end|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "163587": { + "content": "<|im_user|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "163588": { + "content": "<|im_assistant|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "163590": { + "content": "<|start_header_id|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "163591": { + "content": "<|end_header_id|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "163593": { + "content": "[EOT]", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "163594": { + "content": "<|im_system|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "163595": { + "content": "<|tool_calls_section_begin|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "163596": { + "content": "<|tool_calls_section_end|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "163597": { + "content": "<|tool_call_begin|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "163598": { + "content": "<|tool_call_argument_begin|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "163599": { + "content": "<|tool_call_end|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "163601": { + "content": "<|im_middle|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "163602": { + "content": "<|media_begin|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "163603": { + "content": "<|media_content|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "163604": { + "content": "<|media_end|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "163605": { + "content": "<|media_pad|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "163606": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "163607": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "163838": { + "content": "[UNK]", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "163839": { + "content": "[PAD]", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + } + }, + "additional_special_tokens": [ + "<|im_end|>", + "<|im_user|>", + "<|im_assistant|>", + "<|start_header_id|>", + "<|end_header_id|>", + "[EOT]", + "<|im_system|>", + "<|im_middle|>", + "<|media_begin|>", + "<|media_content|>", + "<|media_end|>", + "<|media_pad|>" + ], + "auto_map": { + "AutoProcessor": "kimi_k25_processor.KimiK25Processor", + "AutoTokenizer": [ + "tokenization_kimi.TikTokenTokenizer", + null + ] + }, + "bos_token": "[BOS]", + "clean_up_tokenization_spaces": false, + "eos_token": "[EOS]", + "extra_special_tokens": {}, + "model_max_length": 262144, + "pad_token": "[PAD]", + "padding_side": "left", + "processor_class": "KimiK25Processor", + "tokenizer_class": "TikTokenTokenizer", + "unk_token": "[UNK]", + "chat_template": "{%- macro render_content(msg) -%}\n {%- set c = msg.get('content') -%}\n {%- if c is string -%}\n {{ c }}\n {%- elif c is not none -%}\n {% for content in c -%}\n {% if content['type'] == 'image' or content['type'] == 'image_url' -%}\n <|media_start|>image<|media_content|><|media_pad|><|media_end|>\n {% elif content['type'] == 'video' or content['type']== 'video_url'-%}\n <|kimi_k25_video_placeholder|>\n {% else -%}\n {{ content['text'] }}\n {%- endif -%}\n {%- endfor -%}\n {%- endif -%}\n{%- endmacro -%}\n\n{% macro set_roles(message) -%}\n {%- set role_name = message.get('name') or message['role'] -%}\n {%- if message['role'] == 'user' -%}\n <|im_user|>{{role_name}}<|im_middle|>\n {%- elif message['role'] == 'assistant' -%}\n <|im_assistant|>{{role_name}}<|im_middle|>\n {%- else -%}\n <|im_system|>{{role_name}}<|im_middle|>\n {%- endif -%}\n{%- endmacro -%}\n\n\n{%- macro render_toolcalls(message) -%}\n <|tool_calls_section_begin|>\n {%- for tool_call in message['tool_calls'] -%}\n {%- set formatted_id = tool_call['id'] -%}\n <|tool_call_begin|>{{ formatted_id }}<|tool_call_argument_begin|>{% if tool_call['function']['arguments'] is string %}{{ tool_call['function']['arguments'] }}{% else %}{{ tool_call['function']['arguments'] | tojson }}{% endif %}<|tool_call_end|>\n {%- endfor -%}\n <|tool_calls_section_end|>\n{%- endmacro -%}\n\n\n{# Find last non-tool-call assisitant message #}\n{%- set ns = namespace(last_non_tool_call_assistant_msg=-1) -%}\n{%- for idx in range(messages|length-1, -1, -1) -%}\n {%- if messages[idx]['role'] == 'assistant' and not messages[idx].get('tool_calls') -%}\n {%- set ns.last_non_tool_call_assistant_msg = idx -%}\n {%- break -%}\n {%- endif -%}\n{%- endfor -%}\n\n{# split all messages into history & suffix, reasoning_content in suffix should be reserved.#}\n{%- set hist_msgs = messages[:ns.last_non_tool_call_assistant_msg+1] -%}\n{%- set suffix_msgs = messages[ns.last_non_tool_call_assistant_msg+1:] -%}\n\n{%- if tools -%}\n {%- if tools_ts_str -%}\n <|im_system|>tool_declare<|im_middle|>{{ tools_ts_str }}<|im_end|>\n {%- else -%}\n <|im_system|>tool_declare<|im_middle|>{{ tools | tojson(separators=(',', ':')) }}<|im_end|>\n {%- endif -%}\n{%- endif -%}\n\n{%- if messages|length == 0 or messages[0]['role'] != 'system' -%}\n <|im_system|>system<|im_middle|>You are Kimi, an AI assistant created by Moonshot AI.<|im_end|>\n{%- endif -%}\n \n{%- for message in hist_msgs -%}\n {{set_roles(message)}}\n {%- if message['role'] == 'assistant' -%}\n {{render_content(message)}}\n {%- if message.get('tool_calls') -%}\n {{render_toolcalls(message)}}\n {%- endif -%}\n {%- elif message['role'] == 'tool' -%}\n {%- set tool_call_id = message.tool_call_id -%}\n ## Return of {{ tool_call_id }}\n{{render_content(message)}}\n {%- elif message['content'] is not none -%}\n {{render_content(message)}}\n {%- endif -%}\n <|im_end|>\n{%- endfor -%}\n\n{%- for message in suffix_msgs -%}\n {{set_roles(message)}}\n {%- if message['role'] == 'assistant' -%}\n {%- if thinking is defined and thinking is false -%}\n {{render_content(message)}}\n {%- else -%}\n {%- set rc = message.get('reasoning_content', '') -%}\n {{rc}}{{render_content(message)}}\n {%- endif -%}\n {%- if message.get('tool_calls') -%}\n {{render_toolcalls(message)}}\n {%- endif -%}\n {%- elif message['role'] == 'tool' -%}\n {%- set tool_call_id = message.tool_call_id -%}\n ## Return of {{ tool_call_id }}\n{{render_content(message)}}\n {%- elif message['content'] is not none -%}\n {{render_content(message)}}\n {%- endif -%}\n <|im_end|>\n{%- endfor -%}\n\n\n{%- if add_generation_prompt -%}\n <|im_assistant|>assistant<|im_middle|>\n {%- if thinking is defined and thinking is false -%}\n \n {%- else -%}\n \n {%- endif -%}\n{%- endif -%}" +} \ No newline at end of file diff --git a/tool_declaration_ts.py b/tool_declaration_ts.py new file mode 100644 index 0000000000000000000000000000000000000000..3cc7727ddbeabef16b90ce0219446bc2d4ea9032 --- /dev/null +++ b/tool_declaration_ts.py @@ -0,0 +1,479 @@ +""" +Encode structured tool declaration to typescript style string. +""" +import dataclasses +import json +import logging +from collections.abc import Sequence +from typing import Any + +logger = logging.getLogger(__name__) + +_TS_INDENT = " " +_TS_FIELD_DELIMITER = ",\n" + + +class _SchemaRegistry: + """Registry for schema definitions to handle $ref resolution""" + + def __init__(self): + self.definitions = {} + self.has_self_ref = False + + def register_definitions(self, defs: dict[str, Any]): + """Register schema definitions from $defs section""" + if not defs: + return + for def_name, def_schema in defs.items(): + self.definitions[def_name] = def_schema + + def resolve_ref(self, ref: str) -> dict[str, Any]: + """Resolve a reference to its schema definition""" + if ref == "#": + self.has_self_ref = True + return {"$self_ref": True} + elif ref.startswith("#/$defs/"): + def_name = ref.split("/")[-1] + if def_name not in self.definitions: + raise ValueError(f"Reference not found: {ref}") + return self.definitions[def_name] + else: + raise ValueError(f"Unsupported reference format: {ref}") + + +def _format_description(description: str, indent: str = "") -> str: + return "\n".join([ + f"{indent}// {line}" if line else "" + for line in description.split("\n") + ]) + + +class _BaseType: + description: str + constraints: dict[str, Any] + + def __init__( + self, + extra_props: dict[str, Any], + *, + allowed_constraint_keys: Sequence[str] = (), + ): + self.description = extra_props.get("description", "") + self.constraints = { + k: v + for k, v in extra_props.items() if k in allowed_constraint_keys + } + + def to_typescript_style(self, indent: str = "") -> str: + raise NotImplementedError + + def format_docstring(self, indent: str) -> str: + lines = [] + if self.description: + lines.append(_format_description(self.description, indent)) + if self.constraints: + constraints_str = ", ".join(f"{k}: {v}" for k, v in sorted( + self.constraints.items(), key=lambda kv: kv[0])) + lines.append(f"{indent}// {constraints_str}") + + return "".join(x + "\n" for x in lines) + + +class _ParameterTypeScalar(_BaseType): + type: str + + def __init__(self, type: str, extra_props: dict[str, Any] | None = None): + self.type = type + + allowed_constraint_keys: list[str] = [] + if self.type == "string": + allowed_constraint_keys = ["maxLength", "minLength", "pattern"] + elif self.type in ("number", "integer"): + allowed_constraint_keys = ["maximum", "minimum"] + + super().__init__(extra_props or {}, + allowed_constraint_keys=allowed_constraint_keys) + + def to_typescript_style(self, indent: str = "") -> str: + # Map integer to number in TypeScript + if self.type == "integer": + return "number" + return self.type + + +class _ParameterTypeObject(_BaseType): + properties: list["_Parameter"] + additional_properties: Any | None = None + + def __init__(self, + json_schema_object: dict[str, Any], + registry: _SchemaRegistry | None = None): + super().__init__(json_schema_object) + + self.properties = [] + self.additional_properties = None + + if not json_schema_object: + return + + if "$defs" in json_schema_object and registry: + registry.register_definitions(json_schema_object["$defs"]) + + self.additional_properties = json_schema_object.get( + "additionalProperties") + if isinstance(self.additional_properties, dict): + self.additional_properties = _parse_parameter_type( + self.additional_properties, registry) + + if "properties" not in json_schema_object: + return + + required_parameters = json_schema_object.get("required", []) + optional_parameters = set( + json_schema_object["properties"].keys()) - set(required_parameters) + + self.properties = [ + _Parameter( + name=name, + type=_parse_parameter_type(prop, registry), + optional=name in optional_parameters, + default=prop.get("default") + if isinstance(prop, dict) else None, + ) for name, prop in json_schema_object["properties"].items() + ] + + def to_typescript_style(self, indent: str = "") -> str: + # sort by optional, make the required parameters first + parameters = [p for p in self.properties if not p.optional] + opt_params = [p for p in self.properties if p.optional] + + parameters = sorted(parameters, key=lambda p: p.name) + parameters.extend(sorted(opt_params, key=lambda p: p.name)) + + param_strs = [] + for p in parameters: + one = p.to_typescript_style(indent=indent + _TS_INDENT) + param_strs.append(one) + + if self.additional_properties is not None: + ap_type_str = "any" + if self.additional_properties is True: + ap_type_str = "any" + elif self.additional_properties is False: + ap_type_str = "never" + elif isinstance(self.additional_properties, _ParameterType): + ap_type_str = self.additional_properties.to_typescript_style( + indent=indent + _TS_INDENT) + else: + raise ValueError( + f"Unknown additionalProperties: {self.additional_properties}" + ) + param_strs.append( + f"{indent + _TS_INDENT}[k: string]: {ap_type_str}") + + if not param_strs: + return "{}" + + params_str = _TS_FIELD_DELIMITER.join(param_strs) + if params_str: + # add new line before and after + params_str = f"\n{params_str}\n" + # always wrap with object + return f"{{{params_str}{indent}}}" + + +class _ParameterTypeArray(_BaseType): + item: "_ParameterType" + + def __init__(self, + json_schema_object: dict[str, Any], + registry: _SchemaRegistry | None = None): + super().__init__(json_schema_object, + allowed_constraint_keys=("minItems", "maxItems")) + if json_schema_object.get("items"): + self.item = _parse_parameter_type(json_schema_object["items"], + registry) + else: + self.item = _ParameterTypeScalar(type="any") + + def to_typescript_style(self, indent: str = "") -> str: + item_docstring = self.item.format_docstring(indent + _TS_INDENT) + if item_docstring: + return ("Array<\n" + item_docstring + indent + _TS_INDENT + + self.item.to_typescript_style(indent=indent + _TS_INDENT) + + "\n" + indent + ">") + else: + return f"Array<{self.item.to_typescript_style(indent=indent)}>" + + +class _ParameterTypeEnum(_BaseType): + # support scalar types only + enum: list[str | int | float | bool | None] + + def __init__(self, json_schema_object: dict[str, Any]): + super().__init__(json_schema_object) + self.enum = json_schema_object["enum"] + + # Validate enum values against declared type if present + if "type" in json_schema_object: + typ = json_schema_object["type"] + if isinstance(typ, list): + if len(typ) == 1: + typ = typ[0] + elif len(typ) == 2: + if "null" not in typ: + raise ValueError(f"Enum type {typ} is not supported") + else: + typ = typ[0] if typ[0] != "null" else typ[1] + else: + raise ValueError(f"Enum type {typ} is not supported") + for val in self.enum: + if val is None: + continue + if typ == "string" and not isinstance(val, str): + raise ValueError(f"Enum value {val} is not a string") + elif typ == "number" and not isinstance(val, (int, float)): + raise ValueError(f"Enum value {val} is not a number") + elif typ == "integer" and not isinstance(val, int): + raise ValueError(f"Enum value {val} is not an integer") + elif typ == "boolean" and not isinstance(val, bool): + raise ValueError(f"Enum value {val} is not a boolean") + + def to_typescript_style(self, indent: str = "") -> str: + return " | ".join( + [f'"{e}"' if isinstance(e, str) else str(e) for e in self.enum]) + + +class _ParameterTypeAnyOf(_BaseType): + types: list["_ParameterType"] + + def __init__( + self, + json_schema_object: dict[str, Any], + registry: _SchemaRegistry | None = None, + ): + super().__init__(json_schema_object) + self.types = [ + _parse_parameter_type(t, registry) + for t in json_schema_object["anyOf"] + ] + + def to_typescript_style(self, indent: str = "") -> str: + return " | ".join( + [t.to_typescript_style(indent=indent) for t in self.types]) + + +class _ParameterTypeUnion(_BaseType): + types: list[str] + + def __init__(self, json_schema_object: dict[str, Any]): + super().__init__(json_schema_object) + + mapping = { + "string": "string", + "number": "number", + "integer": "number", + "boolean": "boolean", + "null": "null", + "object": "{}", + "array": "Array", + } + self.types = [mapping[t] for t in json_schema_object["type"]] + + def to_typescript_style(self, indent: str = "") -> str: + return " | ".join(self.types) + + +class _ParameterTypeRef(_BaseType): + ref_name: str + is_self_ref: bool = False + + def __init__(self, json_schema_object: dict[str, Any], + registry: _SchemaRegistry): + super().__init__(json_schema_object) + + ref = json_schema_object["$ref"] + resolved_schema = registry.resolve_ref(ref) + + if resolved_schema.get("$self_ref", False): + self.ref_name = "parameters" + self.is_self_ref = True + else: + self.ref_name = ref.split("/")[-1] + + def to_typescript_style(self, indent: str = "") -> str: + return self.ref_name + + +_ParameterType = (_ParameterTypeScalar + | _ParameterTypeObject + | _ParameterTypeArray + | _ParameterTypeEnum + | _ParameterTypeAnyOf + | _ParameterTypeUnion + | _ParameterTypeRef) + + +@dataclasses.dataclass +class _Parameter: + """ + A parameter in a function, or a field in a object. + It consists of the type as well as the name. + """ + + type: _ParameterType + name: str = "_" + optional: bool = True + default: Any | None = None + + @classmethod + def parse_extended(cls, attributes: dict[str, Any]) -> "_Parameter": + if not attributes: + raise ValueError("attributes is empty") + + return cls( + name=attributes.get("name", "_"), + type=_parse_parameter_type(attributes), + optional=attributes.get("optional", False), + default=attributes.get("default"), + ) + + def to_typescript_style(self, indent: str = "") -> str: + comments = self.type.format_docstring(indent) + + if self.default is not None: + default_repr = (json.dumps(self.default, ensure_ascii=False) + if not isinstance(self.default, (int, float, bool)) + else repr(self.default)) + comments += f"{indent}// Default: {default_repr}\n" + + return ( + comments + + f"{indent}{self.name}{'?' if self.optional else ''}: {self.type.to_typescript_style(indent=indent)}" + ) + + +def _parse_parameter_type( + json_schema_object: dict[str, Any] | bool, + registry: _SchemaRegistry | None = None) -> _ParameterType: + if isinstance(json_schema_object, bool): + if json_schema_object: + return _ParameterTypeScalar(type="any") + else: + logger.warning( + f"Warning: Boolean value {json_schema_object} is not supported, use null instead." + ) + return _ParameterTypeScalar(type="null") + + if "$ref" in json_schema_object and registry: + return _ParameterTypeRef(json_schema_object, registry) + + if "anyOf" in json_schema_object: + return _ParameterTypeAnyOf(json_schema_object, registry) + elif "enum" in json_schema_object: + return _ParameterTypeEnum(json_schema_object) + elif "type" in json_schema_object: + typ = json_schema_object["type"] + if isinstance(typ, list): + return _ParameterTypeUnion(json_schema_object) + elif typ == "object": + return _ParameterTypeObject(json_schema_object, registry) + elif typ == "array": + return _ParameterTypeArray(json_schema_object, registry) + else: + return _ParameterTypeScalar(typ, json_schema_object) + elif json_schema_object == {}: + return _ParameterTypeScalar(type="any") + else: + raise ValueError(f"Invalid JSON Schema object: {json_schema_object}") + + +def _openai_function_to_typescript_style(function: dict[str, Any], ) -> str: + """Convert OpenAI function definition (dict) to TypeScript style string.""" + registry = _SchemaRegistry() + parameters = function.get("parameters") or {} + parsed = _ParameterTypeObject(parameters, registry) + + interfaces = [] + root_interface_name = None + if registry.has_self_ref: + root_interface_name = "parameters" + params_str = _TS_FIELD_DELIMITER.join([ + p.to_typescript_style(indent=_TS_INDENT) for p in parsed.properties + ]) + params_str = f"\n{params_str}\n" if params_str else "" + interface_def = f"interface {root_interface_name} {{{params_str}}}" + interfaces.append(interface_def) + + definitions_copy = dict(registry.definitions) + for def_name, def_schema in definitions_copy.items(): + obj_type = _parse_parameter_type(def_schema, registry) + params_str = obj_type.to_typescript_style() + + description_part = "" + if obj_description := def_schema.get("description", ""): + description_part = _format_description(obj_description) + "\n" + + interface_def = f"{description_part}interface {def_name} {params_str}" + interfaces.append(interface_def) + + interface_str = "\n".join(interfaces) + function_name = function.get("name", "function") + if root_interface_name: + type_def = f"type {function_name} = (_: {root_interface_name}) => any;" + else: + params_str = parsed.to_typescript_style() + type_def = f"type {function_name} = (_: {params_str}) => any;" + + description = function.get("description") + return "\n".join( + filter( + bool, + [ + interface_str, + ((description and _format_description(description)) or ""), + type_def, + ], + )) + + +def encode_tools_to_typescript_style(tools: list[dict[str, Any]], ) -> str: + """ + Convert tools (list of dict) to TypeScript style string. + + Supports OpenAI format: {"type": "function", "function": {...}} + + Args: + tools: List of tool definitions in dict format + + Returns: + TypeScript style string representation of the tools + """ + if not tools: + return "" + + functions = [] + + for tool in tools: + tool_type = tool.get("type") + if tool_type == "function": + func_def = tool.get("function", {}) + if func_def: + functions.append( + _openai_function_to_typescript_style(func_def)) + else: + # Skip unsupported tool types (like "_plugin") + continue + + if not functions: + return "" + + functions_str = "\n".join(functions) + result = "# Tools\n\n" + + if functions_str: + result += "## functions\nnamespace functions {\n" + result += functions_str + "\n" + result += "}\n" + + return result