diff --git a/ms-swift/examples/train/multi-gpu/device_map/train.sh b/ms-swift/examples/train/multi-gpu/device_map/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..3c71d935e7612a08aff0085bf0480a774dcbb5d0 --- /dev/null +++ b/ms-swift/examples/train/multi-gpu/device_map/train.sh @@ -0,0 +1,25 @@ +# 2 * 76GiB +CUDA_VISIBLE_DEVICES=0,1 \ +MAX_PIXELS=1003520 \ +swift sft \ + --model Qwen/Qwen2.5-VL-72B-Instruct \ + --dataset 'modelscope/coco_2014_caption:validation#20000' \ + --train_type lora \ + --torch_dtype bfloat16 \ + --num_train_epochs 1 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --learning_rate 1e-4 \ + --lora_rank 8 \ + --lora_alpha 32 \ + --target_modules all-linear \ + --freeze_vit true \ + --gradient_accumulation_steps 16 \ + --eval_steps 100 \ + --save_steps 100 \ + --save_total_limit 2 \ + --logging_steps 5 \ + --max_length 2048 \ + --output_dir output \ + --warmup_ratio 0.05 \ + --dataloader_num_workers 4 diff --git a/ms-swift/examples/train/multimodal/grounding.sh b/ms-swift/examples/train/multimodal/grounding.sh new file mode 100644 index 0000000000000000000000000000000000000000..bad1f37d293feba5035ca522d37622aa5daf6f64 --- /dev/null +++ b/ms-swift/examples/train/multimodal/grounding.sh @@ -0,0 +1,27 @@ +# 20GiB +# You can refer to `https://github.com/QwenLM/Qwen2.5-VL` for the meaning of the `MAX_PIXELS` parameter. +CUDA_VISIBLE_DEVICES=0 \ +MAX_PIXELS=1003520 \ +swift sft \ + --model Qwen/Qwen2-VL-7B-Instruct \ + --dataset 'AI-ModelScope/coco#20000' \ + --train_type lora \ + --torch_dtype bfloat16 \ + --num_train_epochs 1 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --learning_rate 1e-4 \ + --lora_rank 8 \ + --lora_alpha 32 \ + --target_modules all-linear \ + --freeze_vit true \ + --gradient_accumulation_steps 16 \ + --eval_steps 100 \ + --save_steps 100 \ + --save_total_limit 2 \ + --logging_steps 5 \ + --max_length 2048 \ + --output_dir output \ + --warmup_ratio 0.05 \ + --dataloader_num_workers 4 \ + --dataset_num_proc 4 diff --git a/ms-swift/examples/train/multimodal/lora_llm_full_vit/sft.sh b/ms-swift/examples/train/multimodal/lora_llm_full_vit/sft.sh new file mode 100644 index 0000000000000000000000000000000000000000..d0ef05550b7a3d89b61a2814c79be72d5c5afede --- /dev/null +++ b/ms-swift/examples/train/multimodal/lora_llm_full_vit/sft.sh @@ -0,0 +1,30 @@ +# 4 * 22GiB +# vit/merger lr 1e-5; llm lora lr 1e-4 +NPROC_PER_NODE=4 \ +CUDA_VISIBLE_DEVICES=0,1,2,3 \ +MAX_PIXELS=1003520 \ +swift sft \ + --model Qwen/Qwen2.5-VL-7B-Instruct \ + --dataset 'AI-ModelScope/coco#20000' \ + --train_type custom \ + --optimizer custom \ + --external_plugins 'examples/train/multimodal/lora_llm_full_vit/custom_plugin.py' \ + --torch_dtype bfloat16 \ + --num_train_epochs 1 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --learning_rate 1e-4 \ + --lora_rank 16 \ + --lora_alpha 32 \ + --gradient_accumulation_steps 4 \ + --eval_steps 100 \ + --save_steps 100 \ + --save_total_limit 2 \ + --logging_steps 5 \ + --max_length 8192 \ + --output_dir output \ + --warmup_ratio 0.05 \ + --dataloader_num_workers 4 \ + --dataset_num_proc 4 \ + --deepspeed zero2 \ + --save_only_model true diff --git a/ms-swift/examples/train/multimodal/rlhf/dpo.sh b/ms-swift/examples/train/multimodal/rlhf/dpo.sh new file mode 100644 index 0000000000000000000000000000000000000000..6817f5e351ca03c587c1f209446f1be66ed8bf15 --- /dev/null +++ b/ms-swift/examples/train/multimodal/rlhf/dpo.sh @@ -0,0 +1,33 @@ +# 4*50GiB +# You can refer to `https://github.com/QwenLM/Qwen2.5-VL` for the meaning of the `MAX_PIXELS` parameter. +# --rlhf_type cpo/orpo/simpo/rm are also supported +nproc_per_node=2 + +CUDA_VISIBLE_DEVICES=0,1 \ +NPROC_PER_NODE=$nproc_per_node \ +MAX_PIXELS=1003520 \ +swift rlhf \ + --rlhf_type dpo \ + --model Qwen/Qwen2.5-VL-7B-Instruct \ + --dataset 'swift/RLAIF-V-Dataset#20000' \ + --train_type lora \ + --torch_dtype bfloat16 \ + --num_train_epochs 1 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --learning_rate 1e-4 \ + --lora_rank 8 \ + --lora_alpha 32 \ + --target_modules all-linear \ + --freeze_vit true \ + --gradient_accumulation_steps $(expr 16 / $nproc_per_node) \ + --eval_steps 100 \ + --save_steps 100 \ + --save_total_limit 2 \ + --deepspeed zero2 \ + --logging_steps 5 \ + --max_length 4096 \ + --output_dir output \ + --warmup_ratio 0.05 \ + --dataloader_num_workers 4 \ + --dataset_num_proc 4 diff --git a/ms-swift/examples/train/rlhf/ppo.sh b/ms-swift/examples/train/rlhf/ppo.sh new file mode 100644 index 0000000000000000000000000000000000000000..d152d1f153ecc99aec78fcfb664060e770e6f07e --- /dev/null +++ b/ms-swift/examples/train/rlhf/ppo.sh @@ -0,0 +1,33 @@ +# Currently, it only supports the case where the model and reward_model use the same template/tokenizer. +# Currently, multimodal model PPO is not supported. +nproc_per_node=4 + +CUDA_VISIBLE_DEVICES=0,1,2,3 \ +NPROC_PER_NODE=$nproc_per_node \ +swift rlhf \ + --rlhf_type ppo \ + --model LLM-Research/Meta-Llama-3.1-8B-Instruct \ + --reward_model 'AI-ModelScope/Skywork-Reward-Llama-3.1-8B-v0.2' \ + --train_type lora \ + --dataset 'AI-ModelScope/alpaca-gpt4-data-zh#20000' 'AI-ModelScope/alpaca-gpt4-data-en#20000' \ + --torch_dtype bfloat16 \ + --num_train_epochs 1 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --learning_rate 1e-5 \ + --lora_rank 8 \ + --lora_alpha 32 \ + --target_modules all-linear \ + --gradient_accumulation_steps $(expr 16 / $nproc_per_node) \ + --eval_steps 100 \ + --save_steps 100 \ + --save_total_limit 2 \ + --logging_steps 5 \ + --max_length 2048 \ + --output_dir output \ + --warmup_ratio 0.05 \ + --dataloader_num_workers 4 \ + --deepspeed zero2 \ + --response_length 512 \ + --temperature 0.7 \ + --dataset_num_proc 4 diff --git a/ms-swift/examples/train/seq_cls/qwen2_5/sft.sh b/ms-swift/examples/train/seq_cls/qwen2_5/sft.sh new file mode 100644 index 0000000000000000000000000000000000000000..fe33ee37290b5a70da31db0bbff1ad4e7d0786fc --- /dev/null +++ b/ms-swift/examples/train/seq_cls/qwen2_5/sft.sh @@ -0,0 +1,28 @@ +# If `num_labels` is provided, it will be considered a classification task, +# and AutoModelForSequenceClassification will be used to load the model. +# You can also specify `--model Qwen/Qwen2.5-0.5B-Instruct --use_chat_template true`. +CUDA_VISIBLE_DEVICES=0 \ +swift sft \ + --model Qwen/Qwen2.5-0.5B \ + --train_type lora \ + --dataset 'DAMO_NLP/jd:cls#2000' \ + --torch_dtype bfloat16 \ + --num_train_epochs 1 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --learning_rate 1e-4 \ + --lora_rank 8 \ + --lora_alpha 32 \ + --target_modules all-linear \ + --gradient_accumulation_steps 16 \ + --eval_steps 50 \ + --save_steps 50 \ + --save_total_limit 2 \ + --logging_steps 5 \ + --max_length 2048 \ + --output_dir output \ + --warmup_ratio 0.05 \ + --dataloader_num_workers 4 \ + --num_labels 2 \ + --task_type seq_cls \ + --use_chat_template false diff --git a/ms-swift/examples/train/seq_cls/qwen2_vl/infer.sh b/ms-swift/examples/train/seq_cls/qwen2_vl/infer.sh new file mode 100644 index 0000000000000000000000000000000000000000..a985418e46b35aa084f0a5b37a6649c191cce515 --- /dev/null +++ b/ms-swift/examples/train/seq_cls/qwen2_vl/infer.sh @@ -0,0 +1,5 @@ +CUDA_VISIBLE_DEVICES=0 \ +MAX_PIXELS=1003520 \ +swift infer \ + --adapters output/vx-xxx/checkpoint-xxx \ + --load_data_args true diff --git a/ms-swift/examples/train/tuners/adapter/train.sh b/ms-swift/examples/train/tuners/adapter/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..d334ae6cba0c57c8ea3553a57e1d4b774840ffd7 --- /dev/null +++ b/ms-swift/examples/train/tuners/adapter/train.sh @@ -0,0 +1,16 @@ +# 17GiB +CUDA_VISIBLE_DEVICES=0 \ +swift sft \ + --model Qwen/Qwen2.5-7B-Instruct \ + --train_type adapter \ + --dataset 'swift/self-cognition#1000' \ + --num_train_epochs 1 \ + --per_device_train_batch_size 1 \ + --learning_rate 1e-4 \ + --gradient_accumulation_steps 16 \ + --eval_steps 100 \ + --save_steps 100 \ + --save_total_limit 2 \ + --logging_steps 5 \ + --model_author swift \ + --model_name swift-robot diff --git a/ms-swift/examples/train/tuners/boft/train.sh b/ms-swift/examples/train/tuners/boft/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..d06b865dd60124750d0e3599193cfac1f5a5ac73 --- /dev/null +++ b/ms-swift/examples/train/tuners/boft/train.sh @@ -0,0 +1,16 @@ +# 17GiB +CUDA_VISIBLE_DEVICES=0 \ +swift sft \ + --model Qwen/Qwen2.5-7B-Instruct \ + --train_type boft \ + --dataset 'swift/self-cognition#1000' \ + --num_train_epochs 1 \ + --per_device_train_batch_size 1 \ + --learning_rate 1e-4 \ + --gradient_accumulation_steps 16 \ + --eval_steps 100 \ + --save_steps 100 \ + --save_total_limit 2 \ + --logging_steps 5 \ + --model_author swift \ + --model_name swift-robot diff --git a/ms-swift/examples/train/tuners/dora/train.sh b/ms-swift/examples/train/tuners/dora/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..2bc7d9f236470fb504e9fe26a8f2a53d2c6a9679 --- /dev/null +++ b/ms-swift/examples/train/tuners/dora/train.sh @@ -0,0 +1,19 @@ +# 17.2GiB +CUDA_VISIBLE_DEVICES=0 \ +swift sft \ + --model Qwen/Qwen2.5-7B-Instruct \ + --train_type lora \ + --use_dora true \ + --dataset 'swift/self-cognition#1000' \ + --num_train_epochs 1 \ + --per_device_train_batch_size 1 \ + --learning_rate 1e-4 \ + --lora_rank 8 \ + --lora_alpha 32 \ + --gradient_accumulation_steps 16 \ + --eval_steps 100 \ + --save_steps 100 \ + --save_total_limit 2 \ + --logging_steps 5 \ + --model_author swift \ + --model_name swift-robot diff --git a/ms-swift/examples/train/tuners/galore/train_galore.sh b/ms-swift/examples/train/tuners/galore/train_galore.sh new file mode 100644 index 0000000000000000000000000000000000000000..4728e0e492c0b4bc091529810e7a73ccb8d6e1f5 --- /dev/null +++ b/ms-swift/examples/train/tuners/galore/train_galore.sh @@ -0,0 +1,18 @@ +# 38GiB +CUDA_VISIBLE_DEVICES=0 \ +swift sft \ + --model Qwen/Qwen2.5-7B-Instruct \ + --train_type full \ + --dataset 'swift/self-cognition#1000' \ + --num_train_epochs 1 \ + --per_device_train_batch_size 1 \ + --learning_rate 1e-5 \ + --gradient_accumulation_steps 16 \ + --eval_steps 100 \ + --save_steps 100 \ + --save_total_limit 2 \ + --logging_steps 5 \ + --model_author swift \ + --model_name swift-robot \ + --use_galore true \ + --galore_optim_per_parameter true diff --git a/ms-swift/examples/train/tuners/llamapro/train.sh b/ms-swift/examples/train/tuners/llamapro/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..d0956449d8d09ebfa15b3224cdd78909f9c09c63 --- /dev/null +++ b/ms-swift/examples/train/tuners/llamapro/train.sh @@ -0,0 +1,17 @@ +# 25.4GiB +CUDA_VISIBLE_DEVICES=0 \ +swift sft \ + --model Qwen/Qwen2.5-7B-Instruct \ + --train_type llamapro \ + --dataset 'swift/self-cognition#1000' \ + --llamapro_num_new_blocks 4 \ + --num_train_epochs 1 \ + --per_device_train_batch_size 1 \ + --learning_rate 1e-4 \ + --gradient_accumulation_steps 16 \ + --eval_steps 100 \ + --save_steps 100 \ + --save_total_limit 2 \ + --logging_steps 5 \ + --model_author swift \ + --model_name swift-robot diff --git a/ms-swift/examples/train/tuners/olora/train.sh b/ms-swift/examples/train/tuners/olora/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..1ead995df6c2b933568787c81da262ee12c0c024 --- /dev/null +++ b/ms-swift/examples/train/tuners/olora/train.sh @@ -0,0 +1,19 @@ +# 17GiB +CUDA_VISIBLE_DEVICES=0 \ +swift sft \ + --model Qwen/Qwen2.5-7B-Instruct \ + --train_type lora \ + --dataset 'swift/self-cognition#1000' \ + --num_train_epochs 1 \ + --per_device_train_batch_size 1 \ + --learning_rate 1e-4 \ + --lora_rank 8 \ + --lora_alpha 32 \ + --init_lora_weights olora \ + --gradient_accumulation_steps 16 \ + --eval_steps 100 \ + --save_steps 100 \ + --save_total_limit 2 \ + --logging_steps 5 \ + --model_author swift \ + --model_name swift-robot diff --git a/ms-swift/examples/train/tuners/pissa/train.sh b/ms-swift/examples/train/tuners/pissa/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..9139ba441e5ff06339845f0c36256940167fc43d --- /dev/null +++ b/ms-swift/examples/train/tuners/pissa/train.sh @@ -0,0 +1,19 @@ +# 17GiB +CUDA_VISIBLE_DEVICES=0 \ +swift sft \ + --model Qwen/Qwen2.5-7B-Instruct \ + --train_type lora \ + --dataset 'swift/self-cognition#1000' \ + --num_train_epochs 1 \ + --per_device_train_batch_size 1 \ + --learning_rate 1e-4 \ + --lora_rank 8 \ + --lora_alpha 32 \ + --init_lora_weights pissa \ + --gradient_accumulation_steps 16 \ + --eval_steps 100 \ + --save_steps 100 \ + --save_total_limit 2 \ + --logging_steps 5 \ + --model_author swift \ + --model_name swift-robot diff --git a/ms-swift/examples/train/tuners/qlora/train.sh b/ms-swift/examples/train/tuners/qlora/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..71684537485693088f485b0b7ff97a412ee6e37c --- /dev/null +++ b/ms-swift/examples/train/tuners/qlora/train.sh @@ -0,0 +1,19 @@ +CUDA_VISIBLE_DEVICES=0 \ +swift sft \ + --model Qwen/Qwen2.5-7B-Instruct \ + --train_type lora \ + --dataset 'swift/self-cognition#1000' \ + --num_train_epochs 1 \ + --per_device_train_batch_size 1 \ + --learning_rate 1e-4 \ + --lora_rank 8 \ + --lora_alpha 32 \ + --gradient_accumulation_steps 16 \ + --eval_steps 100 \ + --save_steps 100 \ + --save_total_limit 2 \ + --logging_steps 5 \ + --model_author swift \ + --model_name swift-robot \ + --quant_bits 4 \ + --quant_method bnb diff --git a/ms-swift/examples/train/tuners/reft/train.sh b/ms-swift/examples/train/tuners/reft/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..0b3853bfe9c70e33898b46af6a55e4c2ecb464b1 --- /dev/null +++ b/ms-swift/examples/train/tuners/reft/train.sh @@ -0,0 +1,17 @@ +CUDA_VISIBLE_DEVICES=0 \ +swift sft \ + --model Qwen/Qwen2.5-7B-Instruct \ + --train_type reft \ + --dataset 'swift/self-cognition#1000' \ + --reft_intervention_type 'LoreftIntervention' \ + --num_train_epochs 1 \ + --per_device_train_batch_size 1 \ + --learning_rate 1e-4 \ + --gradient_checkpointing false \ + --gradient_accumulation_steps 16 \ + --eval_steps 100 \ + --save_steps 100 \ + --save_total_limit 2 \ + --logging_steps 5 \ + --model_author swift \ + --model_name swift-robot diff --git a/ms-swift/ms_swift.egg-info/PKG-INFO b/ms-swift/ms_swift.egg-info/PKG-INFO new file mode 100644 index 0000000000000000000000000000000000000000..e931a54d4239ca526694cb74b53a16d1c8ce7a0d --- /dev/null +++ b/ms-swift/ms_swift.egg-info/PKG-INFO @@ -0,0 +1,545 @@ +Metadata-Version: 2.4 +Name: ms_swift +Version: 3.5.0.dev0 +Summary: Swift: Scalable lightWeight Infrastructure for Fine-Tuning +Home-page: https://github.com/modelscope/swift +Author: DAMO ModelScope teams +Author-email: contact@modelscope.cn +License: Apache License 2.0 +Keywords: python,petl,efficient tuners +Platform: UNKNOWN +Classifier: Development Status :: 4 - Beta +Classifier: License :: OSI Approved :: Apache Software License +Classifier: Operating System :: OS Independent +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Programming Language :: Python :: 3.12 +Description-Content-Type: text/markdown +License-File: LICENSE +Requires-Dist: accelerate +Requires-Dist: addict +Requires-Dist: aiohttp +Requires-Dist: attrdict +Requires-Dist: binpacking +Requires-Dist: charset_normalizer +Requires-Dist: cpm_kernels +Requires-Dist: dacite +Requires-Dist: datasets<3.4,>=3.0 +Requires-Dist: einops +Requires-Dist: fastapi +Requires-Dist: gradio>=3.40.0 +Requires-Dist: importlib_metadata +Requires-Dist: jieba +Requires-Dist: matplotlib +Requires-Dist: modelscope>=1.23 +Requires-Dist: nltk +Requires-Dist: numpy<2.0 +Requires-Dist: openai +Requires-Dist: oss2 +Requires-Dist: pandas +Requires-Dist: peft<0.16,>=0.11 +Requires-Dist: pillow +Requires-Dist: requests +Requires-Dist: rouge +Requires-Dist: safetensors +Requires-Dist: scipy +Requires-Dist: sentencepiece +Requires-Dist: simplejson>=3.3.0 +Requires-Dist: sortedcontainers>=1.5.9 +Requires-Dist: tensorboard +Requires-Dist: tiktoken +Requires-Dist: tqdm +Requires-Dist: transformers<4.53,>=4.33 +Requires-Dist: transformers_stream_generator +Requires-Dist: trl<0.18,>=0.13 +Requires-Dist: uvicorn +Requires-Dist: zstandard +Provides-Extra: eval +Requires-Dist: evalscope[opencompass]; extra == "eval" +Requires-Dist: evalscope[vlmeval]; extra == "eval" +Provides-Extra: swanlab +Requires-Dist: swanlab; extra == "swanlab" +Provides-Extra: seq-parallel +Requires-Dist: xtuner; extra == "seq-parallel" +Provides-Extra: all +Requires-Dist: accelerate; extra == "all" +Requires-Dist: addict; extra == "all" +Requires-Dist: aiohttp; extra == "all" +Requires-Dist: attrdict; extra == "all" +Requires-Dist: binpacking; extra == "all" +Requires-Dist: charset_normalizer; extra == "all" +Requires-Dist: cpm_kernels; extra == "all" +Requires-Dist: dacite; extra == "all" +Requires-Dist: datasets<3.4,>=3.0; extra == "all" +Requires-Dist: einops; extra == "all" +Requires-Dist: fastapi; extra == "all" +Requires-Dist: gradio>=3.40.0; extra == "all" +Requires-Dist: importlib_metadata; extra == "all" +Requires-Dist: jieba; extra == "all" +Requires-Dist: matplotlib; extra == "all" +Requires-Dist: modelscope>=1.23; extra == "all" +Requires-Dist: nltk; extra == "all" +Requires-Dist: numpy<2.0; extra == "all" +Requires-Dist: openai; extra == "all" +Requires-Dist: oss2; extra == "all" +Requires-Dist: pandas; extra == "all" +Requires-Dist: peft<0.16,>=0.11; extra == "all" +Requires-Dist: pillow; extra == "all" +Requires-Dist: requests; extra == "all" +Requires-Dist: rouge; extra == "all" +Requires-Dist: safetensors; extra == "all" +Requires-Dist: scipy; extra == "all" +Requires-Dist: sentencepiece; extra == "all" +Requires-Dist: simplejson>=3.3.0; extra == "all" +Requires-Dist: sortedcontainers>=1.5.9; extra == "all" +Requires-Dist: tensorboard; extra == "all" +Requires-Dist: tiktoken; extra == "all" +Requires-Dist: tqdm; extra == "all" +Requires-Dist: transformers<4.53,>=4.33; extra == "all" +Requires-Dist: transformers_stream_generator; extra == "all" +Requires-Dist: trl<0.18,>=0.13; extra == "all" +Requires-Dist: uvicorn; extra == "all" +Requires-Dist: zstandard; extra == "all" +Requires-Dist: evalscope[opencompass]; extra == "all" +Requires-Dist: evalscope[vlmeval]; extra == "all" +Requires-Dist: xtuner; extra == "all" +Requires-Dist: swanlab; extra == "all" +Dynamic: author +Dynamic: author-email +Dynamic: classifier +Dynamic: description +Dynamic: description-content-type +Dynamic: home-page +Dynamic: keywords +Dynamic: license +Dynamic: license-file +Dynamic: provides-extra +Dynamic: requires-dist +Dynamic: summary + +# SWIFT (Scalable lightWeight Infrastructure for Fine-Tuning) + +

+
+ +
+

+

+ModelScope Community Website +
+ 中文   |   English   +

+ +

+ + + + + + + +

+ +

+modelscope%2Fswift | Trendshift +

+ +

+ Paper   | English Documentation   |   中文文档   +

+ +## 📖 Table of Contents +- [Groups](#-Groups) +- [Introduction](#-introduction) +- [News](#-news) +- [Installation](#%EF%B8%8F-installation) +- [Quick Start](#-quick-Start) +- [Usage](#-Usage) +- [License](#-License) +- [Citation](#-citation) + + +## ☎ Groups + +You can contact us and communicate with us by adding our group: + + +[Discord Group](https://discord.com/invite/D27yfEFVz5) | WeChat Group +:-------------------------:|:-------------------------: + | + + +## 📝 Introduction +🍲 ms-swift is an official framework provided by the ModelScope community for fine-tuning and deploying large language models and multi-modal large models. It currently supports the training (pre-training, fine-tuning, human alignment), inference, evaluation, quantization, and deployment of 500+ large models and 200+ multi-modal large models. These large language models (LLMs) include models such as Qwen3, Qwen3-MoE, Qwen2.5, InternLM3, GLM4, Mistral, DeepSeek-R1, Yi1.5, TeleChat2, Baichuan2, and Gemma2. The multi-modal LLMs include models such as Qwen2.5-VL, Qwen2-Audio, Llama3.4, Llava, InternVL2.5, MiniCPM-V-2.6, GLM4v, Xcomposer2.5, Yi-VL, DeepSeek-VL2, Phi3.5-Vision, and GOT-OCR2. + +🍔 Additionally, ms-swift incorporates the latest training technologies, including lightweight techniques such as LoRA, QLoRA, Llama-Pro, LongLoRA, GaLore, Q-GaLore, LoRA+, LISA, DoRA, FourierFt, ReFT, UnSloth, and Liger, as well as human alignment training methods like DPO, GRPO, RM, PPO, KTO, CPO, SimPO, and ORPO. ms-swift supports acceleration of inference, evaluation, and deployment modules using vLLM and LMDeploy, and it supports model quantization with technologies like GPTQ, AWQ, and BNB. Furthermore, ms-swift offers a Gradio-based Web UI and a wealth of best practices. + +**Why choose ms-swift?** + +- 🍎 **Model Types**: Supports 500+ pure text large models, **200+ multi-modal large models**, as well as All-to-All multi-modal models, sequence classification models, and embedding models, **covering the entire process from training to deployment**. +- **Dataset Types**: Comes with 150+ pre-training, fine-tuning, human alignment, multi-modal datasets, and supports custom datasets. +- **Hardware Support**: Compatible with CPU, RTX series, T4/V100, A10/A100/H100, Ascend NPU, MPS, etc. +- 🍊 **Lightweight Training**: Supports lightweight fine-tuning methods like LoRA, QLoRA, DoRA, LoRA+, ReFT, RS-LoRA, LLaMAPro, Adapter, GaLore, Q-Galore, LISA, UnSloth, Liger-Kernel. +- **Distributed Training**: Supports distributed data parallel (DDP), device_map simple model parallelism, DeepSpeed ZeRO2/ZeRO3, FSDP, and other distributed training techniques. +- **Quantization Training**: Supports training quantized models like BNB, AWQ, GPTQ, AQLM, HQQ, EETQ. +- **RLHF Training**: Supports human alignment training methods such as DPO, GRPO, RM, PPO, KTO, CPO, SimPO, ORPO for both pure text and multi-modal large models. +- 🍓 **Multi-Modal Training**: Supports training on different modalities like images, videos, and audio, for tasks like VQA, captioning, OCR, and grounding. +- **Interface Training**: Provides capabilities for training, inference, evaluation, quantization through an interface, completing the whole large model pipeline. +- **Plugin and Extension**: Supports custom model and dataset extensions, as well as customization of components like loss, metric, trainer, loss-scale, callback, optimizer. +- 🍉 **Toolbox Capabilities**: Offers not only training support for large models and multi-modal large models but also covers the entire process of inference, evaluation, quantization, and deployment. +- **Inference Acceleration**: Supports inference acceleration engines like PyTorch, vLLM, LmDeploy, and provides OpenAI API for accelerating inference, deployment, and evaluation modules. +- **Model Evaluation**: Uses EvalScope as the evaluation backend and supports evaluation on 100+ datasets for both pure text and multi-modal models. +- **Model Quantization**: Supports AWQ, GPTQ, and BNB quantized exports, with models that can use vLLM/LmDeploy for inference acceleration and continue training. + + +## 🎉 News +- 🎁 2025.05.11: GRPO now supports custom processing logic for reward models. See the GenRM example [here](./docs/source_en/Instruction/GRPO.md#customized-reward-models) . +- 🎁 2025.04.15: The ms-swift paper has been accepted by AAAI 2025. You can find the paper at [this link](https://ojs.aaai.org/index.php/AAAI/article/view/35383). +- 🎁 2025.03.23: Multi-round GRPO is now supported for training multi-turn dialogue scenarios (e.g., agent tool calling). Please refer to the [training script](https://idealab.alibaba-inc.com/examples/train/grpo/internal/train_multi_round.sh). +- 🎁 2025.03.16: Support for Megatron's parallel training techniques is now available. Please see the [Megatron-SWIFT training documentation](https://swift.readthedocs.io/zh-cn/latest/Instruction/Megatron-SWIFT训练.html). +- 🎁 2025.03.15: Fine-tuning of embedding models for both pure text and multimodal models is supported. Please check the [training script](https://idealab.alibaba-inc.com/examples/train/embedding). +- 🎁 2025.03.05: The hybrid mode for GRPO is supported, with a script for training a 72B model on 4 GPUs (4*80G) available [here](https://idealab.alibaba-inc.com/examples/train/grpo/internal/train_72b_4gpu.sh). Tensor parallelism with vllm is also supported, with the training script available [here](https://idealab.alibaba-inc.com/examples/train/grpo/internal/multi_gpu_mp_colocate.sh). +- 🎁 2025.02.21: The GRPO algorithm now supports LMDeploy, with the training script available [here](https://idealab.alibaba-inc.com/examples/train/grpo/internal/full_lmdeploy.sh). Additionally, the performance of the GRPO algorithm has been tested, achieving a training speed increase of up to 300% using various tricks. Please check the WanDB table [here](https://wandb.ai/tastelikefeet/grpo_perf_test?nw=nwuseryuzezyz). +- 🎁 2025.02.21: The `swift sample` command is now supported. The reinforcement fine-tuning script can be found [here](https://idealab.alibaba-inc.com/docs/source/Instruction/强化微调.md), and the large model API distillation sampling script is available [here](https://idealab.alibaba-inc.com/examples/sampler/distill/distill.sh). +- 🔥 2025.02.12: Support for the GRPO (Group Relative Policy Optimization) training algorithm has been added. Documentation is available [here](https://idealab.alibaba-inc.com/docs/source/Instruction/GRPO.md). +- 🎁 2024.12.04: Major update to **ms-swift 3.0**. Please refer to the [release notes and changes](https://swift.readthedocs.io/zh-cn/latest/Instruction/ReleaseNote3.0.html). +
More + +- 🎉 2024.08.12: The ms-swift paper has been published on arXiv and can be read [here](https://arxiv.org/abs/2408.05517). +- 🔥 2024.08.05: Support for using [evalscope](https://github.com/modelscope/evalscope/) as a backend for evaluating large models and multimodal models. +- 🔥 2024.07.29: Support for using [vllm](https://github.com/vllm-project/vllm) and [lmdeploy](https://github.com/InternLM/lmdeploy) to accelerate inference for large models and multimodal models. When performing infer/deploy/eval, you can specify `--infer_backend vllm/lmdeploy`. +- 🔥 2024.07.24: Support for human preference alignment training for multimodal large models, including DPO/ORPO/SimPO/CPO/KTO/RM/PPO. +- 🔥 2024.02.01: Support for Agent training! The training algorithm is derived from [this paper](https://arxiv.org/pdf/2309.00986.pdf). +
+ +## 🛠️ Installation +To install using pip: +```shell +pip install ms-swift -U +``` + +To install from source: +```shell +# pip install git+https://github.com/modelscope/ms-swift.git + +git clone https://github.com/modelscope/ms-swift.git +cd ms-swift +pip install -e . +``` + +Running Environment: + +| | Range | Recommended | Notes | +| ------------ |--------------| ----------- | ----------------------------------------- | +| python | >=3.9 | 3.10 | | +| cuda | | cuda12 | No need to install if using CPU, NPU, MPS | +| torch | >=2.0 | | | +| transformers | >=4.33 | 4.51 | | +| modelscope | >=1.23 | | | +| peft | >=0.11,<0.16 | || +| trl | >=0.13,<0.18 | 0.17 |RLHF| +| deepspeed | >=0.14 | 0.14.5 | Training | +| vllm | >=0.5.1 | 0.7.3/0.8 | Inference/Deployment/Evaluation | +| lmdeploy | >=0.5 | 0.8 | Inference/Deployment/Evaluation | +| evalscope | >=0.11 | | Evaluation | + +For more optional dependencies, you can refer to [here](https://github.com/modelscope/ms-swift/blob/main/requirements/install_all.sh). + + +## 🚀 Quick Start + +10 minutes of self-cognition fine-tuning of Qwen2.5-7B-Instruct on a single 3090 GPU: + +### Command Line Interface + +```shell +# 22GB +CUDA_VISIBLE_DEVICES=0 \ +swift sft \ + --model Qwen/Qwen2.5-7B-Instruct \ + --train_type lora \ + --dataset 'AI-ModelScope/alpaca-gpt4-data-zh#500' \ + 'AI-ModelScope/alpaca-gpt4-data-en#500' \ + 'swift/self-cognition#500' \ + --torch_dtype bfloat16 \ + --num_train_epochs 1 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --learning_rate 1e-4 \ + --lora_rank 8 \ + --lora_alpha 32 \ + --target_modules all-linear \ + --gradient_accumulation_steps 16 \ + --eval_steps 50 \ + --save_steps 50 \ + --save_total_limit 2 \ + --logging_steps 5 \ + --max_length 2048 \ + --output_dir output \ + --system 'You are a helpful assistant.' \ + --warmup_ratio 0.05 \ + --dataloader_num_workers 4 \ + --model_author swift \ + --model_name swift-robot +``` + +Tips: + +- If you want to train with a custom dataset, you can refer to [this guide](https://swift.readthedocs.io/en/latest/Customization/Custom-dataset.html) to organize your dataset format and specify `--dataset `. +- The `--model_author` and `--model_name` parameters are only effective when the dataset includes `swift/self-cognition`. +- To train with a different model, simply modify `--model `. +- By default, ModelScope is used for downloading models and datasets. If you want to use HuggingFace, simply specify `--use_hf true`. + +After training is complete, use the following command to infer with the trained weights: + +- Here, `--adapters` should be replaced with the last checkpoint folder generated during training. Since the adapters folder contains the training parameter file `args.json`, there is no need to specify `--model`, `--system` separately; Swift will automatically read these parameters. To disable this behavior, you can set `--load_args false`. + +```shell +# Using an interactive command line for inference. +CUDA_VISIBLE_DEVICES=0 \ +swift infer \ + --adapters output/vx-xxx/checkpoint-xxx \ + --stream true \ + --temperature 0 \ + --max_new_tokens 2048 + +# merge-lora and use vLLM for inference acceleration +CUDA_VISIBLE_DEVICES=0 \ +swift infer \ + --adapters output/vx-xxx/checkpoint-xxx \ + --stream true \ + --merge_lora true \ + --infer_backend vllm \ + --max_model_len 8192 \ + --temperature 0 \ + --max_new_tokens 2048 +``` + +Finally, use the following command to push the model to ModelScope: + +```shell +CUDA_VISIBLE_DEVICES=0 \ +swift export \ + --adapters output/vx-xxx/checkpoint-xxx \ + --push_to_hub true \ + --hub_model_id '' \ + --hub_token '' \ + --use_hf false +``` + + +### Web-UI +The Web-UI is a **zero-threshold** training and deployment interface solution based on Gradio interface technology. For more details, you can check [here](https://swift.readthedocs.io/en/latest/GetStarted/Web-UI.html). + +```shell +SWIFT_UI_LANG=en swift web-ui +``` + +![image.png](./docs/resources/web-ui-en.jpg) + +### Using Python + +ms-swift also supports training and inference using Python. Below is pseudocode for training and inference. For more details, you can refer to [here](https://github.com/modelscope/ms-swift/blob/main/examples/notebook/qwen2_5-self-cognition/self-cognition-sft.ipynb). + +Training: + +```python +# Retrieve the model and template, and add a trainable LoRA module +model, tokenizer = get_model_tokenizer(model_id_or_path, ...) +template = get_template(model.model_meta.template, tokenizer, ...) +model = Swift.prepare_model(model, lora_config) + +# Download and load the dataset, and encode the text into tokens +train_dataset, val_dataset = load_dataset(dataset_id_or_path, ...) +train_dataset = EncodePreprocessor(template=template)(train_dataset, num_proc=num_proc) +val_dataset = EncodePreprocessor(template=template)(val_dataset, num_proc=num_proc) + +# Train the model +trainer = Seq2SeqTrainer( + model=model, + args=training_args, + data_collator=template.data_collator, + train_dataset=train_dataset, + eval_dataset=val_dataset, + template=template, +) +trainer.train() +``` +Inference: + +```python +# Perform inference using the native PyTorch engine +engine = PtEngine(model_id_or_path, adapters=[lora_checkpoint]) +infer_request = InferRequest(messages=[{'role': 'user', 'content': 'who are you?'}]) +request_config = RequestConfig(max_tokens=max_new_tokens, temperature=temperature) + +resp_list = engine.infer([infer_request], request_config) +print(f'response: {resp_list[0].choices[0].message.content}') +``` + +## ✨ Usage +Here is a minimal example of training to deployment using ms-swift. For more details, you can check the [examples](https://github.com/modelscope/ms-swift/tree/main/examples). + +- If you want to use other models or datasets (including multimodal models and datasets), you only need to modify `--model` to specify the corresponding model's ID or path, and modify `--dataset` to specify the corresponding dataset's ID or path. +- By default, ModelScope is used for downloading models and datasets. If you want to use HuggingFace, simply specify `--use_hf true`. + +| Useful Links | +| ------ | +| [🔥Command Line Parameters](https://swift.readthedocs.io/en/latest/Instruction/Command-line-parameters.html) | +| [Supported Models and Datasets](https://swift.readthedocs.io/en/latest/Instruction/Supported-models-and-datasets.html) | +| [Custom Models](https://swift.readthedocs.io/en/latest/Customization/Custom-model.html), [🔥Custom Datasets](https://swift.readthedocs.io/en/latest/Customization/Custom-dataset.html) | +| [LLM Tutorial](https://github.com/modelscope/modelscope-classroom/tree/main/LLM-tutorial) | + +### Training + +Supported Training Methods: + +| Method | Full-Parameter | LoRA | QLoRA | Deepspeed | Multi-Node | Multi-Modal | +|------------------------------------|--------------------------------------------------------------|---------------------------------------------------------------------------------------------|--------------------------------------------------------------|--------------------------------------------------------------|--------------------------------------------------------------|----------------------------------------------------------------------------------------------| +| Pre-training | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/pretrain/train.sh) | ✅ | ✅ | ✅ | ✅ | ✅ | +| Instruction Supervised Fine-tuning | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/full/train.sh) | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/lora_sft.sh) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/qlora) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-gpu/deepspeed) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multimodal) | +| DPO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/dpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/dpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/dpo.sh) | +| GRPO Training | [✅]((https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/grpo_zero2.sh)) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/multi_node) | ✅ | +| Reward Model Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/rm.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/rm.sh) | ✅ | ✅ | +| PPO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/ppo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/ppo.sh) | ✅ | ❌ | +| KTO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/kto.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/kto.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/kto.sh) | +| CPO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/cpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/cpo.sh) | ✅ | ✅ | +| SimPO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/simpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/simpo.sh) | ✅ | ✅ | +| ORPO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/orpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/orpo.sh) | ✅ | ✅ | +| Classification Model Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/seq_cls/qwen2_5/sft.sh) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/seq_cls/qwen2_vl/sft.sh) | +| Embedding Model Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/embedding/train_gte.sh) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/embedding/train_gme.sh) | + + + +Pre-training: +```shell +# 8*A100 +NPROC_PER_NODE=8 \ +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ +swift pt \ + --model Qwen/Qwen2.5-7B \ + --dataset swift/chinese-c4 \ + --streaming true \ + --train_type full \ + --deepspeed zero2 \ + --output_dir output \ + --max_steps 10000 \ + ... +``` + +Fine-tuning: +```shell +CUDA_VISIBLE_DEVICES=0 swift sft \ + --model Qwen/Qwen2.5-7B-Instruct \ + --dataset AI-ModelScope/alpaca-gpt4-data-en \ + --train_type lora \ + --output_dir output \ + ... +``` + +RLHF: +```shell +CUDA_VISIBLE_DEVICES=0 swift rlhf \ + --rlhf_type dpo \ + --model Qwen/Qwen2.5-7B-Instruct \ + --dataset hjh0119/shareAI-Llama3-DPO-zh-en-emoji \ + --train_type lora \ + --output_dir output \ + ... +``` + + +### Inference +```shell +CUDA_VISIBLE_DEVICES=0 swift infer \ + --model Qwen/Qwen2.5-7B-Instruct \ + --stream true \ + --infer_backend pt \ + --max_new_tokens 2048 + +# LoRA +CUDA_VISIBLE_DEVICES=0 swift infer \ + --model Qwen/Qwen2.5-7B-Instruct \ + --adapters swift/test_lora \ + --stream true \ + --infer_backend pt \ + --temperature 0 \ + --max_new_tokens 2048 +``` + +### Interface Inference +```shell +CUDA_VISIBLE_DEVICES=0 swift app \ + --model Qwen/Qwen2.5-7B-Instruct \ + --stream true \ + --infer_backend pt \ + --max_new_tokens 2048 +``` + +### Deployment +```shell +CUDA_VISIBLE_DEVICES=0 swift deploy \ + --model Qwen/Qwen2.5-7B-Instruct \ + --infer_backend vllm +``` + +### Sampling +```shell +CUDA_VISIBLE_DEVICES=0 swift sample \ + --model LLM-Research/Meta-Llama-3.1-8B-Instruct \ + --sampler_engine pt \ + --num_return_sequences 5 \ + --dataset AI-ModelScope/alpaca-gpt4-data-zh#5 +``` + +### Evaluation +```shell +CUDA_VISIBLE_DEVICES=0 swift eval \ + --model Qwen/Qwen2.5-7B-Instruct \ + --infer_backend lmdeploy \ + --eval_backend OpenCompass \ + --eval_dataset ARC_c +``` + +### Quantization +```shell +CUDA_VISIBLE_DEVICES=0 swift export \ + --model Qwen/Qwen2.5-7B-Instruct \ + --quant_bits 4 --quant_method awq \ + --dataset AI-ModelScope/alpaca-gpt4-data-zh \ + --output_dir Qwen2.5-7B-Instruct-AWQ +``` + +### Push Model +```shell +swift export \ + --model \ + --push_to_hub true \ + --hub_model_id '' \ + --hub_token '' +``` + +## 🏛 License + +This framework is licensed under the [Apache License (Version 2.0)](https://github.com/modelscope/modelscope/blob/master/LICENSE). For models and datasets, please refer to the original resource page and follow the corresponding License. + +## 📎 Citation + +```bibtex +@misc{zhao2024swiftascalablelightweightinfrastructure, + title={SWIFT:A Scalable lightWeight Infrastructure for Fine-Tuning}, + author={Yuze Zhao and Jintao Huang and Jinghan Hu and Xingjun Wang and Yunlin Mao and Daoze Zhang and Zeyinzi Jiang and Zhikai Wu and Baole Ai and Ang Wang and Wenmeng Zhou and Yingda Chen}, + year={2024}, + eprint={2408.05517}, + archivePrefix={arXiv}, + primaryClass={cs.CL}, + url={https://arxiv.org/abs/2408.05517}, +} +``` + +## Star History + +[![Star History Chart](https://api.star-history.com/svg?repos=modelscope/swift&type=Date)](https://star-history.com/#modelscope/ms-swift&Date) diff --git a/ms-swift/ms_swift.egg-info/not-zip-safe b/ms-swift/ms_swift.egg-info/not-zip-safe new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/ms-swift/ms_swift.egg-info/not-zip-safe @@ -0,0 +1 @@ + diff --git a/ms-swift/requirements/install_all.sh b/ms-swift/requirements/install_all.sh new file mode 100644 index 0000000000000000000000000000000000000000..3bcce3ffc89b0bd8e566736a6ddbbb02c2679d98 --- /dev/null +++ b/ms-swift/requirements/install_all.sh @@ -0,0 +1,12 @@ +# please use python=3.10, cuda12.* +# sh requirements/install_all.sh +pip install "vllm>=0.5.1" -U +pip install "lmdeploy>=0.5" -U --no-deps +pip install autoawq -U --no-deps +pip install auto_gptq optimum bitsandbytes -U +pip install git+https://github.com/modelscope/ms-swift.git +pip install timm -U +pip install deepspeed -U +pip install qwen_vl_utils qwen_omni_utils decord librosa pyav icecream soundfile -U +pip install liger_kernel nvitop pre-commit -U +# flash-attn: https://github.com/Dao-AILab/flash-attention/releases diff --git a/ms-swift/requirements/seq_parallel.txt b/ms-swift/requirements/seq_parallel.txt new file mode 100644 index 0000000000000000000000000000000000000000..9a09f5d6292273d9e890431c40c158de1d5a6d50 --- /dev/null +++ b/ms-swift/requirements/seq_parallel.txt @@ -0,0 +1 @@ +xtuner diff --git a/ms-swift/requirements/swanlab.txt b/ms-swift/requirements/swanlab.txt new file mode 100644 index 0000000000000000000000000000000000000000..85381455c6b3bb922c0c4c23a82810e3ddb41f04 --- /dev/null +++ b/ms-swift/requirements/swanlab.txt @@ -0,0 +1 @@ +swanlab diff --git a/ms-swift/scripts/benchmark/config/tuner.json b/ms-swift/scripts/benchmark/config/tuner.json new file mode 100644 index 0000000000000000000000000000000000000000..542fe90b8c60fb5454a6584cc80b2c9f4ff85102 --- /dev/null +++ b/ms-swift/scripts/benchmark/config/tuner.json @@ -0,0 +1,301 @@ +{ + "cmd": "sft", + "requirements":{ + "gpu": "1", + "ddp": "1" + }, + "eval_requirements": { + "gpu": "1" + }, + "eval_dataset": ["ceval", "gsm8k", "arc"], + "args": { + "model": "Qwen/Qwen-7B-Chat", + "dataset": "iic/ms_agent", + "per_device_train_batch_size": 1, + "max_length": 2048, + "loss_scale": "react", + "gradient_accumulation_steps": 16, + "learning_rate": 5e-5, + "attn_impl": "flash_attn", + "eval_steps": 2000, + "save_steps": 2000, + "num_train_epochs": 2, + "gradient_checkpointing": true, + "weight_decay": 0.01, + "warmup_ratio": 0.03, + "save_total_limit": 2, + "logging_steps": 10 + }, + "experiment": [ + { + "name": "lora", + "args": { + "train_type": "lora", + "lora_rank": 8, + "lora_alpha": 32 + } + }, + { + "name": "lora+packing", + "args": { + "train_type": "lora", + "lora_rank": 8, + "lora_alpha": 32, + "packing": true, + "eval_steps": 200, + "save_steps": 200 + } + }, + { + "name": "lora+packing+ddp", + "requirements":{ + "gpu": "2", + "ddp": "2" + }, + "args": { + "train_type": "lora", + "lora_rank": 8, + "lora_alpha": 32, + "packing": true, + "eval_steps": 100, + "save_steps": 100 + } + }, + { + "name": "lora+packing+lazytokenize", + "args": { + "train_type": "lora", + "lora_rank": 8, + "lora_alpha": 32, + "packing": true, + "lazy_tokenize": true, + "eval_steps": 200, + "save_steps": 200 + } + }, + { + "name": "lora+", + "args": { + "train_type": "lora", + "lora_rank": 8, + "lora_alpha": 32, + "lorap_lr_ratio": 16.0 + } + }, + { + "name": "rslora", + "args": { + "train_type": "lora", + "lora_rank": 8, + "lora_alpha": 32, + "use_rslora": true + } + }, + { + "name": "dora", + "args": { + "train_type": "lora", + "lora_rank": 8, + "lora_alpha": 32, + "use_dora": true + } + }, + { + "name": "lora+neftune", + "args": { + "train_type": "lora", + "lora_rank": 8, + "lora_alpha": 32, + "neftune_noise_alpha": 15.0 + } + }, + { + "name": "llamapro", + "args": { + "train_type": "llamapro", + "llamapro_num_new_blocks": "4" + } + }, + { + "name": "full", + "requirements":{ + "gpu": "1", + "ddp": "1" + }, + "args": { + "train_type": "full" + } + }, + { + "name": "reft", + "requirements":{ + "gpu": "1", + "ddp": "1" + }, + "args": { + "train_type": "reft", + "gradient_checkpointing": "false", + "loss_scale": "default" + } + }, + { + "name": "full+galore128+quantize", + "requirements":{ + "gpu": "1", + "ddp": "1" + }, + "args": { + "train_type": "full", + "use_galore": "true", + "galore_rank": "128", + "galore_update_proj_gap": "200", + "galore_optim_per_parameter": "false", + "galore_with_embedding": "false", + "galore_quantization": "true" + } + }, + { + "name": "full+galore128+quantize+proj_quant", + "requirements":{ + "gpu": "1", + "ddp": "1" + }, + "args": { + "train_type": "full", + "use_galore": "true", + "galore_rank": "128", + "galore_update_proj_gap": "200", + "galore_optim_per_parameter": "false", + "galore_with_embedding": "false", + "galore_quantization": "true", + "galore_proj_quant": "true" + } + }, + { + "name": "full+galore128", + "requirements":{ + "gpu": "1", + "ddp": "1" + }, + "args": { + "train_type": "full", + "use_galore": "true", + "galore_rank": "128", + "galore_update_proj_gap": "200", + "galore_optim_per_parameter": "false", + "galore_with_embedding": "false" + } + }, + { + "name": "full+galore64", + "requirements":{ + "gpu": "1", + "ddp": "1" + }, + "args": { + "train_type": "full", + "use_galore": "true", + "galore_rank": "64", + "galore_update_proj_gap": "200", + "galore_optim_per_parameter": "false", + "galore_with_embedding": "false" + } + }, + { + "name": "full+galore32", + "requirements":{ + "gpu": "1", + "ddp": "1" + }, + "args": { + "train_type": "full", + "use_galore": "true", + "galore_rank": "32", + "galore_update_proj_gap": "200", + "galore_optim_per_parameter": "false", + "galore_with_embedding": "false" + } + }, + { + "name": "full+galore_emb", + "requirements":{ + "gpu": "1", + "ddp": "1" + }, + "args": { + "train_type": "full", + "use_galore": "true", + "galore_rank": "128", + "galore_update_proj_gap": "200", + "galore_optim_per_parameter": "false", + "galore_with_embedding": "true" + } + }, + { + "name": "full+galore_perparam", + "requirements":{ + "gpu": "1", + "ddp": "1" + }, + "args": { + "train_type": "full", + "use_galore": "true", + "galore_rank": "128", + "galore_update_proj_gap": "200", + "galore_optim_per_parameter": "true", + "galore_with_embedding": "false" + } + }, + { + "name": "adalora", + "args": { + "train_type": "adalora", + "lora_rank": 8, + "lora_alpha": 32 + } + }, + { + "name": "adapter", + "args": { + "train_type": "adapter" + } + }, + { + "name": "full+lisa_2", + "info": "lisa 2layers + full", + "args": { + "train_type": "full", + "lisa_activated_layers": 2, + "lisa_step_interval": 20 + } + }, + { + "name": "full+lisa_4", + "info": "lisa 4layers + full", + "args": { + "train_type": "full", + "lisa_activated_layers": 4, + "lisa_step_interval": 20 + } + }, + { + "name": "unsloth+lora+q4", + "info": "unsloth lora quantization bit 4", + "args": { + "train_type": "lora", + "tuner_backend": "unsloth", + "quantization_bit": 4, + "model": "LLM-Research/Meta-Llama-3-8B-Instruct" + } + }, + { + "name": "unsloth+full", + "info": "unsloth full", + "args": { + "train_type": "full", + "tuner_backend": "unsloth", + "model_type": "LLM-Research/Meta-Llama-3-8B-Instruct" + } + } + ] +} diff --git a/ms-swift/scripts/benchmark/exp.py b/ms-swift/scripts/benchmark/exp.py new file mode 100644 index 0000000000000000000000000000000000000000..b71863ab5775daa3656b1601b807d5bd81951fef --- /dev/null +++ b/ms-swift/scripts/benchmark/exp.py @@ -0,0 +1,50 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import argparse +import os +import os.path + +from exp_utils import ExpManager, find_all_config + +from swift.utils import * + +logger = get_logger() + + +def parse_args(): + parser = argparse.ArgumentParser(description='Simple args for swift experiments.') + parser.add_argument( + '--config', + type=str, + default=None, + required=True, + help='The experiment config file', + ) + parser.add_argument( + '--save_dir', + type=str, + default='./experiment', + required=False, + help='The experiment output folder', + ) + + args = parser.parse_args() + return args + + +def llm_exp(): + args = parse_args() + config: str = args.config + config = config.split(',') + os.makedirs(args.save_dir, exist_ok=True) + all_configs = [] + if not isinstance(config, list): + config = [config] + for dir_or_file in config: + all_configs.extend(find_all_config(dir_or_file)) + args.config = all_configs + exp_manager = ExpManager() + exp_manager.begin(args) + + +if __name__ == '__main__': + llm_exp() diff --git a/ms-swift/scripts/benchmark/generate_report.py b/ms-swift/scripts/benchmark/generate_report.py new file mode 100644 index 0000000000000000000000000000000000000000..6d618151d4365a8b06ccb68afb1e1a097c88c2e1 --- /dev/null +++ b/ms-swift/scripts/benchmark/generate_report.py @@ -0,0 +1,433 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import dataclasses +import os +from dataclasses import dataclass +from typing import Any, Dict, List + +import json +import numpy as np + +from swift.llm.template import split_str_parts_by + + +@dataclass +class ModelOutput: + + group: str = None + + name: str = None + + cmd: str = None + + requirements: Dict[str, str] = dataclasses.field(default_factory=dict) + + args: Dict[str, Any] = dataclasses.field(default_factory=dict) + + memory: str = None + + train_time: float = None + + train_samples: int = None + + train_samples_per_second: float = None + + last_model_checkpoint: str = None + + best_model_checkpoint: str = None + + best_metric: Any = None + + global_step: int = None + + num_total_parameters: float = None + + num_trainable_parameters: float = None + + num_buffers: float = None + + trainable_parameters_percentage: float = None + + train_dataset_info: str = None + + val_dataset_info: str = None + + train_create_time: float = None + + eval_tokens: int = None + + eval_time: float = None + + reports: Dict[str, Any] = None + + train_loss: float = None + + @property + def tuner_hyper_params(self): + hyper_params = '' + args = self.args + if 'sft_type' not in args: + return '' + if args['sft_type'] in ('lora', 'adalora', 'longlora'): + if 'lora_rank' in args: + hyper_params += f'rank={args["lora_rank"]}/' \ + f'target={args["lora_target_modules"]}/' \ + f'alpha={args["lora_alpha"]}/' \ + f'lr_ratio={args.get("lora_lr_ratio", None)}/' \ + f'use_rslora={args.get("use_rslora", False)}/' \ + f'use_dora={args.get("use_dora", False)}' + else: + hyper_params = '' + if args['sft_type'] == 'full': + if 'use_galore' in args and args['use_galore'] == 'true': + hyper_params += f'galore_rank={args["galore_rank"]}/' \ + f'galore_per_parameter={args["galore_optim_per_parameter"]}/' \ + f'galore_with_embedding={args["galore_with_embedding"]}/' + if args['sft_type'] == 'llamapro': + hyper_params += f'num_blocks={args["llamapro_num_new_blocks"]}/' + if 'neftune_noise_alpha' in args and args['neftune_noise_alpha']: + hyper_params += f'neftune_noise_alpha={args["neftune_noise_alpha"]}/' + + if hyper_params.endswith('/'): + hyper_params = hyper_params[:-1] + return hyper_params + + @property + def hyper_parameters(self): + if 'learning_rate' not in self.args: + return '' + return f'lr={self.args["learning_rate"]}/' \ + f'epoch={self.args["num_train_epochs"]}' + + @property + def train_speed(self): + if self.train_samples_per_second: + return f'{self.train_samples_per_second:.2f}({self.train_samples} samples/{self.train_time:.2f} seconds)' + else: + return '' + + @property + def infer_speed(self): + if self.eval_tokens: + return f'{self.eval_tokens / self.eval_time:.2f}({self.eval_tokens} tokens/{self.eval_time:.2f} seconds)' + return '' + + +def generate_sft_report(outputs: List[ModelOutput]): + gsm8k_accs = [] + arc_accs = [] + ceval_accs = [] + for output in outputs: + gsm8k_acc = None + arc_acc = None + ceval_acc = None + for report in (output.reports or []): + if report['name'] == 'gsm8k': + gsm8k_acc = report['score'] + if report['name'] == 'arc': + arc_acc = report['score'] + if report['name'] == 'ceval': + ceval_acc = report['score'] + gsm8k_accs.append(gsm8k_acc) + arc_accs.append(arc_acc) + ceval_accs.append(ceval_acc) + + tab = '| exp_name | model_type | dataset | ms-bench mix ratio | tuner | tuner_params | trainable params(M) | flash_attn | gradient_checkpointing | hypers | memory | train speed(samples/s) | infer speed(tokens/s) | train_loss | eval_loss | gsm8k weighted acc | arc weighted acc | ceval weighted acc |\n' \ + '| -------- | ---------- | ------- | -------------------| ----- | ------------ | ------------------- | -----------| ---------------------- | ------ | ------ | ---------------------- | --------------------- | ---------- | --------- | ------------------ | ---------------- | ------------------ |\n' # noqa + min_best_metric = 999. + min_train_loss = 999. + if outputs: + min_best_metric = min([output.best_metric or 999. for output in outputs]) + min_train_loss = min([output.train_loss or 999. for output in outputs]) + + max_gsm8k = 0.0 + if gsm8k_accs: + max_gsm8k = max([gsm8k or 0. for gsm8k in gsm8k_accs]) + + max_arc = 0.0 + if arc_accs: + max_arc = max([arc or 0. for arc in arc_accs]) + + max_ceval = 0.0 + if ceval_accs: + max_ceval = max([ceval or 0. for ceval in ceval_accs]) + + for output, gsm8k_acc, arc_acc, ceval_acc in zip(outputs, gsm8k_accs, arc_accs, ceval_accs): + use_flash_attn = output.args.get('use_flash_attn', '') + use_gc = output.args.get('gradient_checkpointing', '') + memory = output.memory + train_speed = output.train_speed + infer_speed = output.infer_speed + + is_best_metric = np.isclose(min_best_metric, output.best_metric or 999.0) + is_best_loss = np.isclose(min_train_loss, output.train_loss or 999.0) + is_best_gsm8k = np.isclose(max_gsm8k, gsm8k_acc or 0.0) + is_best_arc = np.isclose(max_arc, arc_acc or 0.0) + is_best_ceval = np.isclose(max_ceval, ceval_acc or 0.0) + + if not is_best_metric: + best_metric = '' if not output.best_metric else f'{output.best_metric:.2f}' + else: + best_metric = '' if not output.best_metric else f'**{output.best_metric:.2f}**' + + if not is_best_loss: + train_loss = '' if not output.train_loss else f'{output.train_loss:.2f}' + else: + train_loss = '' if not output.train_loss else f'**{output.train_loss:.2f}**' + + if not is_best_gsm8k: + gsm8k_acc = '' if not gsm8k_acc else f'{gsm8k_acc:.3f}' + else: + gsm8k_acc = '' if not gsm8k_acc else f'**{gsm8k_acc:.3f}**' + + if not is_best_arc: + arc_acc = '' if not arc_acc else f'{arc_acc:.3f}' + else: + arc_acc = '' if not arc_acc else f'**{arc_acc:.3f}**' + + if not is_best_ceval: + ceval_acc = '' if not ceval_acc else f'{ceval_acc:.3f}' + else: + ceval_acc = '' if not ceval_acc else f'**{ceval_acc:.3f}**' + + line = f'|{output.name}|' \ + f'{output.args["model_type"]}|' \ + f'{output.args.get("dataset")}|' \ + f'{output.args.get("train_dataset_mix_ratio", 0.)}|' \ + f'{output.args.get("sft_type")}|' \ + f'{output.tuner_hyper_params}|' \ + f'{output.num_trainable_parameters}({output.trainable_parameters_percentage})|' \ + f'{use_flash_attn}|' \ + f'{use_gc}|' \ + f'{output.hyper_parameters}|' \ + f'{memory}|' \ + f'{train_speed}|' \ + f'{infer_speed}|' \ + f'{best_metric}|' \ + f'{train_loss}|' \ + f'{gsm8k_acc}|' \ + f'{arc_acc}|' \ + f'{ceval_acc}|\n' + tab += line + return tab + + +def generate_export_report(outputs: List[ModelOutput]): + tab = '| exp_name | model_type | calibration dataset | quantization method | quantization bits | infer speed(tokens/s) | gsm8k weighted acc | arc weighted acc | ceval weighted acc |\n' \ + '| -------- | ---------- | ------------------- | ------------------- | ----------------- | --------------------- | ------------------ | ---------------- | ------------------ |\n' # noqa + + gsm8k_accs = [] + arc_accs = [] + ceval_accs = [] + for output in outputs: + gsm8k_acc = None + arc_acc = None + ceval_acc = None + for report in (output.reports or []): + if report['name'] == 'gsm8k': + gsm8k_acc = report['score'] + if report['name'] == 'arc': + arc_acc = report['score'] + if report['name'] == 'ceval': + ceval_acc = report['score'] + gsm8k_accs.append(gsm8k_acc) + arc_accs.append(arc_acc) + ceval_accs.append(ceval_acc) + + max_gsm8k = 0.0 + if gsm8k_accs: + max_gsm8k = max([gsm8k or 0. for gsm8k in gsm8k_accs]) + + max_arc = 0.0 + if arc_accs: + max_arc = max([arc or 0. for arc in arc_accs]) + + max_ceval = 0.0 + if ceval_accs: + max_ceval = max([ceval or 0. for ceval in ceval_accs]) + + for output, gsm8k_acc, arc_acc, ceval_acc in zip(outputs, gsm8k_accs, arc_accs, ceval_accs): + infer_speed = output.infer_speed + is_best_gsm8k = np.isclose(max_gsm8k, gsm8k_acc or 0.0) + is_best_arc = np.isclose(max_arc, arc_acc or 0.0) + is_best_ceval = np.isclose(max_ceval, ceval_acc or 0.0) + + if not is_best_gsm8k: + gsm8k_acc = '' if not gsm8k_acc else f'{gsm8k_acc:.3f}' + else: + gsm8k_acc = '' if not gsm8k_acc else f'**{gsm8k_acc:.3f}**' + + if not is_best_arc: + arc_acc = '' if not arc_acc else f'{arc_acc:.3f}' + else: + arc_acc = '' if not arc_acc else f'**{arc_acc:.3f}**' + + if not is_best_ceval: + ceval_acc = '' if not ceval_acc else f'{ceval_acc:.3f}' + else: + ceval_acc = '' if not ceval_acc else f'**{ceval_acc:.3f}**' + + if output.train_dataset_info: + dataset_info = f'{output.args["dataset"]}/{output.train_dataset_info}' + else: + dataset_info = f'{output.args["dataset"]}' + line = f'|{output.name}|' \ + f'{output.args["model_type"]}|' \ + f'{dataset_info}|' \ + f'{output.args["quant_method"]}|' \ + f'{output.args["quant_bits"]}|' \ + f'{infer_speed}|' \ + f'{gsm8k_acc}|' \ + f'{arc_acc}|' \ + f'{ceval_acc}|\n' + tab += line + return tab + + +def parse_output(file): + with open(file, 'r', encoding='utf-8') as f: + content = json.load(f) + + name = content['name'] + group = content['group'] + cmd = content['cmd'] + requirements = content['requirements'] + args = content['args'] + create_time = float(content.get('create_time') or 0) + content = content['record'] + if cmd == 'export': + best_model_checkpoint = content['best_model_checkpoint'] + eval_tokens = 0 + eval_time = 0.0 + eval_result = None + if 'eval_result' in content: + eval_result = content['eval_result'] + eval_tokens = eval_result['generation_info']['tokens'] + eval_time = eval_result['generation_info']['time'] + eval_result = eval_result['report'] + return ModelOutput( + group=group, + name=name, + cmd=cmd, + requirements=requirements, + args=args, + best_model_checkpoint=best_model_checkpoint, + eval_time=eval_time, + eval_tokens=eval_tokens, + reports=eval_result, + ) + else: + memory = None + train_time = None + train_samples = None + train_samples_per_second = None + last_model_checkpoint = None + best_model_checkpoint = None + best_metric = None + global_step = None + train_dataset_info = None + val_dataset_info = None + num_trainable_parameters = None + num_buffers = None + trainable_parameters_percentage = None + num_total_parameters = None + train_loss = None + if 'memory' in content: + memory = content['memory'] + memory = '/'.join(memory.values()) + if 'train_time' in content: + train_time = content['train_time']['train_runtime'] + train_samples = content['train_time']['n_train_samples'] + train_samples_per_second = content['train_time']['train_samples_per_second'] + if 'last_model_checkpoint' in content: + last_model_checkpoint = content['last_model_checkpoint'] + if 'best_model_checkpoint' in content: + best_model_checkpoint = content['best_model_checkpoint'] + if 'best_metric' in content: + best_metric = content['best_metric'] + if 'log_history' in content: + train_loss = content['log_history'][-1]['train_loss'] + if 'global_step' in content: + global_step = content['global_step'] + if 'dataset_info' in content: + train_dataset_info = content['dataset_info'].get('train_dataset') + val_dataset_info = content['dataset_info'].get('val_dataset') + if 'model_info' in content: + # model_info like: SwiftModel: 6758.4041M Params (19.9885M Trainable [0.2958%]), 16.7793M Buffers. + str_dict = split_str_parts_by(content['model_info'], [ + 'SwiftModel:', 'CausalLM:', 'Seq2SeqLM:', 'LMHeadModel:', 'M Params (', 'M Trainable [', ']), ', + 'M Buffers.' + ]) + str_dict = {c['key']: c['content'] for c in str_dict} + if 'SwiftModel:' in str_dict: + num_total_parameters = float(str_dict['SwiftModel:']) + elif 'CausalLM:' in str_dict: + num_total_parameters = float(str_dict['CausalLM:']) + elif 'Seq2SeqLM:' in str_dict: + num_total_parameters = float(str_dict['Seq2SeqLM:']) + elif 'LMHeadModel:' in str_dict: + num_total_parameters = float(str_dict['LMHeadModel:']) + num_trainable_parameters = float(str_dict['M Params (']) + num_buffers = float(str_dict[']), ']) + trainable_parameters_percentage = str_dict['M Trainable ['] + + eval_tokens = 0 + eval_time = 0.0 + eval_result = None + if 'eval_result' in content: + eval_result = content['eval_result'] + eval_tokens = eval_result['generation_info']['tokens'] + eval_time = eval_result['generation_info']['time'] + eval_result = eval_result['report'] + + return ModelOutput( + group=group, + name=name, + cmd=cmd, + requirements=requirements, + args=args, + memory=memory, + train_time=train_time, + train_samples=train_samples, + train_samples_per_second=train_samples_per_second, + last_model_checkpoint=last_model_checkpoint, + best_model_checkpoint=best_model_checkpoint, + best_metric=best_metric, + global_step=global_step, + train_dataset_info=train_dataset_info, + val_dataset_info=val_dataset_info, + train_create_time=create_time, + num_total_parameters=num_total_parameters, + num_trainable_parameters=num_trainable_parameters, + num_buffers=num_buffers, + trainable_parameters_percentage=trainable_parameters_percentage, + eval_time=eval_time, + eval_tokens=eval_tokens, + reports=eval_result, + train_loss=train_loss, + ) + + +def generate_reports(): + outputs = [] + for dirs, _, files in os.walk('./experiment'): + for file in files: + abs_file = os.path.join(dirs, file) + if not abs_file.endswith('.json') or 'ipynb' in abs_file: + continue + + outputs.append(parse_output(abs_file)) + + all_groups = set([output.group for output in outputs]) + for group in all_groups: + group_outputs = [output for output in outputs if output.group == group] + print(f'=================Printing the sft cmd result of exp {group}==================\n\n') + print(generate_sft_report([output for output in group_outputs if output.cmd in ('sft', 'eval')])) + # print(f'=================Printing the dpo result of exp {group}==================') + # print(generate_dpo_report([output for output in outputs if output.cmd == 'dpo'])) + print(f'=================Printing the export cmd result of exp {group}==================\n\n') + print(generate_export_report([output for output in group_outputs if output.cmd == 'export'])) + print('=================Printing done==================\n\n') + + +if __name__ == '__main__': + generate_reports() diff --git a/ms-swift/scripts/utils/run_dataset_info.py b/ms-swift/scripts/utils/run_dataset_info.py new file mode 100644 index 0000000000000000000000000000000000000000..76878125fc486c52563dc0879ff7394e1209018d --- /dev/null +++ b/ms-swift/scripts/utils/run_dataset_info.py @@ -0,0 +1,106 @@ +import os +import re + +import numpy as np + +from swift.llm import DATASET_MAPPING, EncodePreprocessor, get_model_tokenizer, get_template, load_dataset +from swift.utils import stat_array + +os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com' + + +def get_cache_mapping(fpath): + with open(fpath, 'r', encoding='utf-8') as f: + text = f.read() + idx = text.find('| Dataset ID |') + text = text[idx:] + text_list = text.split('\n')[2:] + cache_mapping = {} # dataset_id -> (dataset_size, stat) + for text in text_list: + if not text: + continue + items = text.split('|') + key = items[1] if items[1] != '-' else items[6] + key = re.search(r'\[(.+?)\]', key).group(1) + stat = items[3:5] + if stat[0] == '-': + stat = ('huge dataset', '-') + cache_mapping[key] = stat + return cache_mapping + + +def get_dataset_id(key): + for dataset_id in key: + if dataset_id is not None: + break + return dataset_id + + +def run_dataset(key, template, cache_mapping): + ms_id, hf_id, _ = key + dataset_meta = DATASET_MAPPING[key] + tags = ', '.join(tag for tag in dataset_meta.tags) or '-' + dataset_id = ms_id or hf_id + use_hf = ms_id is None + if ms_id is not None: + ms_id = f'[{ms_id}](https://modelscope.cn/datasets/{ms_id})' + else: + ms_id = '-' + if hf_id is not None: + hf_id = f'[{hf_id}](https://huggingface.co/datasets/{hf_id})' + else: + hf_id = '-' + subsets = '
'.join(subset.name for subset in dataset_meta.subsets) + + if dataset_meta.huge_dataset: + dataset_size = 'huge dataset' + stat_str = '-' + elif dataset_id in cache_mapping: + dataset_size, stat_str = cache_mapping[dataset_id] + else: + num_proc = 4 + dataset, _ = load_dataset(f'{dataset_id}:all', strict=False, num_proc=num_proc, use_hf=use_hf) + dataset_size = len(dataset) + random_state = np.random.RandomState(42) + idx_list = random_state.choice(dataset_size, size=min(dataset_size, 100000), replace=False) + encoded_dataset = EncodePreprocessor(template)(dataset.select(idx_list), num_proc=num_proc) + + input_ids = encoded_dataset['input_ids'] + token_len = [len(tokens) for tokens in input_ids] + stat = stat_array(token_len)[0] + stat_str = f"{stat['mean']:.1f}±{stat['std']:.1f}, min={stat['min']}, max={stat['max']}" + + return f'|{ms_id}|{subsets}|{dataset_size}|{stat_str}|{tags}|{hf_id}|' + + +def write_dataset_info() -> None: + fpaths = ['docs/source/Instruction/支持的模型和数据集.md', 'docs/source_en/Instruction/Supported-models-and-datasets.md'] + cache_mapping = get_cache_mapping(fpaths[0]) + res_text_list = [] + res_text_list.append('| Dataset ID | Subset Name | Dataset Size | Statistic (token) | Tags | HF Dataset ID |') + res_text_list.append('| ---------- | ----------- | -------------| ------------------| ---- | ------------- |') + + all_keys = list(DATASET_MAPPING.keys()) + all_keys = sorted(all_keys, key=lambda x: get_dataset_id(x)) + _, tokenizer = get_model_tokenizer('Qwen/Qwen2.5-7B-Instruct', load_model=False) + template = get_template(tokenizer.model_meta.template, tokenizer) + try: + for i, key in enumerate(all_keys): + res = run_dataset(key, template, cache_mapping) + res_text_list.append(res) + print(res) + finally: + for fpath in fpaths: + with open(fpath, 'r', encoding='utf-8') as f: + text = f.read() + idx = text.find('| Dataset ID |') + + new_text = '\n'.join(res_text_list) + text = text[:idx] + new_text + '\n' + with open(fpath, 'w', encoding='utf-8') as f: + f.write(text) + print(f'数据集总数: {len(all_keys)}') + + +if __name__ == '__main__': + write_dataset_info() diff --git a/ms-swift/scripts/utils/run_template.py b/ms-swift/scripts/utils/run_template.py new file mode 100644 index 0000000000000000000000000000000000000000..6c6b0445e47bb1d5326f171e39c8fd2b432d3934 --- /dev/null +++ b/ms-swift/scripts/utils/run_template.py @@ -0,0 +1,8 @@ +from swift.llm import TemplateType + +if __name__ == '__main__': + template_name_list = TemplateType.get_template_name_list() + tn_gen = ', '.join([tn for tn in template_name_list if 'generation' in tn]) + tn_chat = ', '.join([tn for tn in template_name_list if 'generation' not in tn]) + print(f'Text Generation: {tn_gen}') + print(f'Chat: {tn_chat}') diff --git a/ms-swift/swift/__init__.py b/ms-swift/swift/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..059e2ea13a3eba89b8346ed7421d62dcd83e057d --- /dev/null +++ b/ms-swift/swift/__init__.py @@ -0,0 +1,55 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from .utils.import_utils import _LazyModule + +if TYPE_CHECKING: + from .version import __version__, __release_datetime__ + from .tuners import (Adapter, AdapterConfig, AdapterModule, SwiftModel, LoRA, LoRAConfig, SWIFT_MAPPING, + AdaLoraConfig, LoftQConfig, LoHaConfig, LoKrConfig, LoraConfig, OFTConfig, PeftConfig, + PeftModel, PeftModelForCausalLM, ResTuningConfig, SideConfig, PeftModelForSeq2SeqLM, + PeftModelForSequenceClassification, PeftModelForTokenClassification, PrefixTuningConfig, + PromptEncoderConfig, PromptLearningConfig, PromptTuningConfig, get_peft_config, get_peft_model, + get_peft_model_state_dict, Prompt, PromptConfig, PromptModule, SwiftConfig, SwiftOutput, Swift, + SwiftTuners, LongLoRAConfig, LongLoRA, LongLoRAModelType, SCETuning, SCETuningConfig) + from .trainers import (EvaluationStrategy, FSDPOption, HPSearchBackend, HubStrategy, IntervalStrategy, + SchedulerType, ShardedDDPOption, TrainingArguments, Seq2SeqTrainingArguments, Trainer, + Seq2SeqTrainer) + from .utils import get_logger +else: + _import_structure = { + 'version': ['__release_datetime__', '__version__'], + 'tuners': [ + 'Adapter', 'AdapterConfig', 'AdapterModule', 'SwiftModel', 'LoRA', 'LoRAConfig', 'SWIFT_MAPPING', + 'LoraConfig', 'AdaLoraConfig', 'LoftQConfig', 'LoHaConfig', 'LoKrConfig', 'OFTConfig', 'PeftConfig', + 'ResTuningConfig', 'SideConfig', 'PeftModel', 'PeftModelForCausalLM', 'PeftModelForSeq2SeqLM', + 'PeftModelForSequenceClassification', 'PeftModelForTokenClassification', 'PrefixTuningConfig', + 'PromptEncoderConfig', 'PromptLearningConfig', 'PromptTuningConfig', 'get_peft_config', 'get_peft_model', + 'get_peft_model_state_dict', 'Prompt', 'PromptConfig', 'PromptModule', 'SwiftConfig', 'SwiftOutput', + 'Swift', 'SwiftTuners', 'LongLoRAConfig', 'LongLoRA', 'LongLoRAModelType', 'SCETuning', 'SCETuningConfig' + ], + 'trainers': [ + 'EvaluationStrategy', + 'FSDPOption', + 'HPSearchBackend', + 'HubStrategy', + 'IntervalStrategy', + 'SchedulerType', + 'ShardedDDPOption', + 'TrainingArguments', + 'Seq2SeqTrainingArguments', + 'Trainer', + 'Seq2SeqTrainer', + ], + 'utils': ['get_logger'] + } + + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/ms-swift/swift/cli/__init__.py b/ms-swift/swift/cli/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ms-swift/swift/cli/__pycache__/__init__.cpython-310.pyc b/ms-swift/swift/cli/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec0f1fe8c9aa6ed669ca58d2d879bb9944c7d26d Binary files /dev/null and b/ms-swift/swift/cli/__pycache__/__init__.cpython-310.pyc differ diff --git a/ms-swift/swift/cli/_megatron/pt.py b/ms-swift/swift/cli/_megatron/pt.py new file mode 100644 index 0000000000000000000000000000000000000000..b5df60b2cfca10dab15d9f95653c7e9eb8d1f102 --- /dev/null +++ b/ms-swift/swift/cli/_megatron/pt.py @@ -0,0 +1,4 @@ +from swift.megatron import megatron_pt_main + +if __name__ == '__main__': + megatron_pt_main() diff --git a/ms-swift/swift/cli/_megatron/sft.py b/ms-swift/swift/cli/_megatron/sft.py new file mode 100644 index 0000000000000000000000000000000000000000..1cd9b19b29bea7157bb8dffe48585e13dfd6d16e --- /dev/null +++ b/ms-swift/swift/cli/_megatron/sft.py @@ -0,0 +1,4 @@ +from swift.megatron import megatron_sft_main + +if __name__ == '__main__': + megatron_sft_main() diff --git a/ms-swift/swift/cli/app.py b/ms-swift/swift/cli/app.py new file mode 100644 index 0000000000000000000000000000000000000000..ec4e79741a8643f1eaba1baa3a2b93771f0a3158 --- /dev/null +++ b/ms-swift/swift/cli/app.py @@ -0,0 +1,4 @@ +from swift.llm import app_main + +if __name__ == '__main__': + app_main() diff --git a/ms-swift/swift/cli/eval.py b/ms-swift/swift/cli/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..402305ea4ec0d8e8304b90ec7a8d7a11b24de8af --- /dev/null +++ b/ms-swift/swift/cli/eval.py @@ -0,0 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from swift.llm import eval_main + +if __name__ == '__main__': + eval_main() diff --git a/ms-swift/swift/cli/export.py b/ms-swift/swift/cli/export.py new file mode 100644 index 0000000000000000000000000000000000000000..508f1be4bfa2b69e2c5e6f27c6077230f30118d4 --- /dev/null +++ b/ms-swift/swift/cli/export.py @@ -0,0 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from swift.llm import export_main + +if __name__ == '__main__': + export_main() diff --git a/ms-swift/swift/cli/main.py b/ms-swift/swift/cli/main.py new file mode 100644 index 0000000000000000000000000000000000000000..8924b03bab39f0e5fac7c15183fd90a9ca3f6911 --- /dev/null +++ b/ms-swift/swift/cli/main.py @@ -0,0 +1,76 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import importlib.util +import os +import subprocess +import sys +from typing import Dict, List, Optional + +from swift.utils import get_logger + +logger = get_logger() + +ROUTE_MAPPING: Dict[str, str] = { + 'pt': 'swift.cli.pt', + 'sft': 'swift.cli.sft', + 'infer': 'swift.cli.infer', + 'merge-lora': 'swift.cli.merge_lora', + 'web-ui': 'swift.cli.web_ui', + 'deploy': 'swift.cli.deploy', + 'rollout': 'swift.cli.rollout', + 'rlhf': 'swift.cli.rlhf', + 'sample': 'swift.cli.sample', + 'export': 'swift.cli.export', + 'eval': 'swift.cli.eval', + 'app': 'swift.cli.app', +} + + +def use_torchrun() -> bool: + nproc_per_node = os.getenv('NPROC_PER_NODE') + nnodes = os.getenv('NNODES') + if nproc_per_node is None and nnodes is None: + return False + return True + + +def get_torchrun_args() -> Optional[List[str]]: + if not use_torchrun(): + return + torchrun_args = [] + for env_key in ['NPROC_PER_NODE', 'MASTER_PORT', 'NNODES', 'NODE_RANK', 'MASTER_ADDR']: + env_val = os.getenv(env_key) + if env_val is None: + continue + torchrun_args += [f'--{env_key.lower()}', env_val] + return torchrun_args + + +def _compat_web_ui(argv): + # [compat] + method_name = argv[0] + if method_name in {'web-ui', 'web_ui'} and ('--model' in argv or '--adapters' in argv or '--ckpt_dir' in argv): + argv[0] = 'app' + logger.warning('Please use `swift app`.') + + +def cli_main(route_mapping: Optional[Dict[str, str]] = None) -> None: + route_mapping = route_mapping or ROUTE_MAPPING + argv = sys.argv[1:] + _compat_web_ui(argv) + method_name = argv[0].replace('_', '-') + argv = argv[1:] + file_path = importlib.util.find_spec(route_mapping[method_name]).origin + torchrun_args = get_torchrun_args() + python_cmd = sys.executable + if torchrun_args is None or method_name not in {'pt', 'sft', 'rlhf', 'infer'}: + args = [python_cmd, file_path, *argv] + else: + args = [python_cmd, '-m', 'torch.distributed.run', *torchrun_args, file_path, *argv] + print(f"run sh: `{' '.join(args)}`", flush=True) + result = subprocess.run(args) + if result.returncode != 0: + sys.exit(result.returncode) + + +if __name__ == '__main__': + cli_main() diff --git a/ms-swift/swift/cli/pt.py b/ms-swift/swift/cli/pt.py new file mode 100644 index 0000000000000000000000000000000000000000..1ca2aabd8aef4353e4ba63710700f5e33e60b7ed --- /dev/null +++ b/ms-swift/swift/cli/pt.py @@ -0,0 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from swift.llm import pt_main + +if __name__ == '__main__': + pt_main() diff --git a/ms-swift/swift/cli/rollout.py b/ms-swift/swift/cli/rollout.py new file mode 100644 index 0000000000000000000000000000000000000000..631ac6f8d9abf2abcf35c1fcfce568c66b528959 --- /dev/null +++ b/ms-swift/swift/cli/rollout.py @@ -0,0 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from swift.llm import rollout_main + +if __name__ == '__main__': + rollout_main() diff --git a/ms-swift/swift/hub/__pycache__/__init__.cpython-310.pyc b/ms-swift/swift/hub/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..302b387b9c71be73917b22bbc60574bf490c950a Binary files /dev/null and b/ms-swift/swift/hub/__pycache__/__init__.cpython-310.pyc differ diff --git a/ms-swift/swift/hub/__pycache__/hub.cpython-310.pyc b/ms-swift/swift/hub/__pycache__/hub.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6cb7b616119dfdf2113d4e7b7cd05e660ba4df71 Binary files /dev/null and b/ms-swift/swift/hub/__pycache__/hub.cpython-310.pyc differ diff --git a/ms-swift/swift/llm/__pycache__/__init__.cpython-310.pyc b/ms-swift/swift/llm/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8bd2e45bd91105bcbe01dff12105277d649b6b55 Binary files /dev/null and b/ms-swift/swift/llm/__pycache__/__init__.cpython-310.pyc differ diff --git a/ms-swift/swift/llm/__pycache__/data_loader.cpython-310.pyc b/ms-swift/swift/llm/__pycache__/data_loader.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c577a5c56aee9dad01b2d9cdbc469d490be9ebdb Binary files /dev/null and b/ms-swift/swift/llm/__pycache__/data_loader.cpython-310.pyc differ diff --git a/ms-swift/swift/llm/app/__init__.py b/ms-swift/swift/llm/app/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f0c28325418864086c7fbf1368120e4c522931ee --- /dev/null +++ b/ms-swift/swift/llm/app/__init__.py @@ -0,0 +1 @@ +from .app import SwiftApp, app_main diff --git a/ms-swift/swift/llm/argument/app_args.py b/ms-swift/swift/llm/argument/app_args.py new file mode 100644 index 0000000000000000000000000000000000000000..17389327de9f6084061c07f8f0d300ea0ce0f4b1 --- /dev/null +++ b/ms-swift/swift/llm/argument/app_args.py @@ -0,0 +1,38 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from dataclasses import dataclass +from typing import Literal, Optional + +from swift.utils import find_free_port, get_logger +from ..model import get_matched_model_meta +from ..template import get_template_meta +from .deploy_args import DeployArguments +from .webui_args import WebUIArguments + +logger = get_logger() + + +@dataclass +class AppArguments(WebUIArguments, DeployArguments): + base_url: Optional[str] = None + studio_title: Optional[str] = None + is_multimodal: Optional[bool] = None + + lang: Literal['en', 'zh'] = 'en' + verbose: bool = False + + def _init_torch_dtype(self) -> None: + if self.base_url: + self.model_meta = get_matched_model_meta(self.model) + return + super()._init_torch_dtype() + + def __post_init__(self): + super().__post_init__() + self.server_port = find_free_port(self.server_port) + if self.model_meta: + if self.system is None: + self.system = get_template_meta(self.model_meta.template).default_system + if self.is_multimodal is None: + self.is_multimodal = self.model_meta.is_multimodal + if self.is_multimodal is None: + self.is_multimodal = False diff --git a/ms-swift/swift/llm/data_loader.py b/ms-swift/swift/llm/data_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..fe20854ca68dbc43b0fd28ceb535d8e8434be7da --- /dev/null +++ b/ms-swift/swift/llm/data_loader.py @@ -0,0 +1,105 @@ +from typing import Optional + +import torch +import torch.distributed as dist +from torch.utils.data import DataLoader + + +class BatchSamplerShard: + + def __init__(self, total_samples: int, batch_size: int, shuffle: bool, drop_last: bool, data_seed: Optional[int]): + self.total_samples = total_samples // self.world_size + self.batch_size = batch_size + self.shuffle = shuffle + self.drop_last = drop_last + self.base_seed = data_seed or 0 + self.curr_seed = self.base_seed + + @property + def rank(self): + return dist.get_rank() if dist.is_initialized() else 0 + + @property + def world_size(self): + return dist.get_world_size() if dist.is_initialized() else 1 + + def __iter__(self): + start_idx = self.rank * self.total_samples + if self.shuffle: + generator = torch.Generator() + generator.manual_seed(self.curr_seed) + total_idx = torch.randperm(self.total_samples * self.world_size, generator=generator).tolist() + total_idx = total_idx[start_idx:start_idx + self.total_samples] + else: + total_idx = list(range(start_idx, start_idx + self.total_samples)) + + batch = [] + # Last batch if not complete will be dropped. + for idx in total_idx: + batch.append(idx) + if len(batch) == self.batch_size: + yield batch + batch = [] + if not self.drop_last and len(batch) > 0: + yield batch + return + + def set_epoch(self, epoch: int): + self.curr_seed = self.base_seed + epoch + + def __len__(self) -> int: + if self.drop_last: + return self.total_samples // self.batch_size + else: + return (self.total_samples + self.batch_size - 1) // self.batch_size + + +class DataLoaderShard(DataLoader): + + def __init__(self, dataset, batch_sampler: BatchSamplerShard, **dataloader_params): + self.batch_sampler = batch_sampler + super().__init__(dataset, batch_sampler=self.batch_sampler, **dataloader_params) + + def set_epoch(self, epoch: int): + self.batch_sampler.set_epoch(epoch) + + +class DataLoaderDispatcher: + + def __init__(self, base_dataloader): + self.base_dataloader = base_dataloader + + @property + def rank(self): + return dist.get_rank(self.group) if dist.is_initialized() else 0 + + @property + def world_size(self): + return dist.get_world_size(self.group) if dist.is_initialized() else 1 + + @property + def group(self): + return dist.group.WORLD if dist.is_initialized() else 1 + + def _scatter_object_list(self, inputs): + if not dist.is_initialized(): + return inputs[0] + outputs = [None] + global_src_rank = dist.get_global_rank(self.group, 0) + dist.scatter_object_list(outputs, inputs, global_src_rank, group=self.group) + return outputs[0] + + def __iter__(self): + base_iter = iter(self.base_dataloader) + while True: + if self.rank == 0: + try: + data = [next(base_iter) for _ in range(self.world_size)] + except StopIteration: + data = [None] * self.world_size + data = self._scatter_object_list(data) + else: + data = self._scatter_object_list(None) + if data is None: + break + yield data diff --git a/ms-swift/swift/llm/dataset/dataset/__pycache__/mllm.cpython-310.pyc b/ms-swift/swift/llm/dataset/dataset/__pycache__/mllm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..677be59d2d6cf9db3af6f6007af419758395ca7b Binary files /dev/null and b/ms-swift/swift/llm/dataset/dataset/__pycache__/mllm.cpython-310.pyc differ diff --git a/ms-swift/swift/llm/dataset/dataset/llm.py b/ms-swift/swift/llm/dataset/dataset/llm.py new file mode 100644 index 0000000000000000000000000000000000000000..de3efe9afcff40e35c112589e21adcaf3c6fae5a --- /dev/null +++ b/ms-swift/swift/llm/dataset/dataset/llm.py @@ -0,0 +1,856 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import ast +import re +from functools import partial +from typing import Any, Dict, List, Optional, Tuple, Union + +import json +import numpy as np + +from ...template import split_str_parts_by +from ..preprocessor import (AlpacaPreprocessor, ClsGenerationPreprocessor, ClsPreprocessor, MessagesPreprocessor, + ResponsePreprocessor, RowPreprocessor, TextGenerationPreprocessor) +from ..register import DatasetMeta, SubsetDataset, register_dataset + + +class AlpacaZhPreprocessor(AlpacaPreprocessor): + + @classmethod + def concat_inst_input(cls, instruction, input_): + if input_ and input_.startswith('输入:'): + input_ = input_[3:] + return super().concat_inst_input(instruction, input_) + + +register_dataset( + DatasetMeta( + ms_dataset_id='AI-ModelScope/alpaca-gpt4-data-zh', + hf_dataset_id='llm-wizard/alpaca-gpt4-data-zh', + preprocess_func=AlpacaZhPreprocessor(), + tags=['chat', 'general', '🔥'], + )) + + +class LongAlpacaPreprocessor(AlpacaPreprocessor): + + def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: + response = row['response'] + prefix_prompt = 'Answer: ' + if response and response.startswith(prefix_prompt): + response = response[len(prefix_prompt):].strip() + row['output'] = response + return super().preprocess(row) + + +register_dataset( + DatasetMeta( + ms_dataset_id='AI-ModelScope/LongAlpaca-12k', + hf_dataset_id='Yukang/LongAlpaca-12k', + preprocess_func=LongAlpacaPreprocessor(), + tags=['long-sequence', 'QA'], + )) + + +class RuozhibaPreprocessor(RowPreprocessor): + + def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: + title = row['title'] if row.get('title', None) is not None else row['content'] + abs = row['abs'] if 'abs' in row else None + if abs and abs != title: + title = title + ',' + abs + + pattern = r'\d+[\.,\s,\、](.+)' + match = re.search(pattern, title) + if match: + title = match.group(1) + if title: + return {'messages': [{'role': 'assistant', 'content': title}]} + + +register_dataset( + DatasetMeta( + ms_dataset_id='AI-ModelScope/ruozhiba', + subsets=['post-annual', 'title-good', 'title-norm'], + preprocess_func=RuozhibaPreprocessor(), + tags=['pretrain', '🔥'])) + + +class MathTrnPreprocessor(ResponsePreprocessor): + + def preprocess(self, row): + query = row['query'] + output = row['response'] + row = { + 'query': query, + 'response': output, + } + return super().preprocess(row) + + +register_dataset( + DatasetMeta(ms_dataset_id='AI-ModelScope/math-trn-format', preprocess_func=MathTrnPreprocessor(), tags=['math'])) + + +def _repair_ms_bench(messages: str) -> Optional[List[Dict[str, str]]]: + if isinstance(messages, str): + messages = ast.literal_eval(messages) + default_system = 'You are a helpful assistant.' + messages: List[Dict[str, str]] + if messages[0]['from'] == 'system' and messages[0]['value'] == default_system: + messages.pop(0) + # skip MOSS + for c in messages: + value = c['value'].lower() + if 'moss' in value or 'human:' in value or 'assistant:' in value or 'user:' in value: + return + return messages + + +register_dataset( + DatasetMeta( + ms_dataset_id='iic/ms_bench', + preprocess_func=MessagesPreprocessor(repair_messages=_repair_ms_bench), + tags=['chat', 'general', 'multi-round', '🔥'])) + + +def _repair_agent_messages(messages: List[Dict[str, str]], use_mini: bool) -> Optional[List[Dict[str, str]]]: + if use_mini: + pattern = r'\d\. {"plugin_name": "(.+?)"' + if messages[0]['from'] != 'system': + return + system = messages[0]['value'] + find_list = re.findall(pattern, system) + if len(set(find_list)) <= 1: + return + return messages + + +register_dataset( + DatasetMeta( + ms_dataset_id='damo/MSAgent-Bench', + subsets=[ + SubsetDataset( + preprocess_func=MessagesPreprocessor(repair_messages=partial(_repair_agent_messages, use_mini=False))), + SubsetDataset( + name='mini', + preprocess_func=MessagesPreprocessor(repair_messages=partial(_repair_agent_messages, use_mini=True)), + is_weak_subset=True) + ], + split=['train', 'validation'], + tags=['chat', 'agent', 'multi-round'])) + +advertise_gen_prompt = """Task: Generating advertisements based on keywords. +Keywords: {{QUERY}} +Advertisements:""" + +register_dataset( + DatasetMeta( + ms_dataset_id='lvjianjin/AdvertiseGen', + hf_dataset_id='shibing624/AdvertiseGen', + preprocess_func=TextGenerationPreprocessor( + prompt=advertise_gen_prompt, columns={ + 'content': 'query', + 'summary': 'response' + }), + tags=['text-generation', '🔥'], + split=['train', 'validation'], + )) + + +class FireflyPreprocessor(ResponsePreprocessor): + _firefly_kind_list = { + 'ProseGeneration', 'MRC', 'JinYongGeneration', 'TextCorrection', 'ClassicalChinese', 'BELLE', 'StoryGeneration', + 'Couplet', 'Cot', 'Dictionary', 'Translation', 'Program', 'SentimentAnalyze', 'OpenQA', 'AncientPoem', + 'TextMatching', 'NLI', 'Summary', 'KeywordRecognition', 'ProductDesc', 'LyricGeneration', 'Composition', + 'MusicComment', 'NER' + } + + def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: + if row['kind'] not in FireflyPreprocessor._firefly_kind_list: + return + return super().preprocess(row) + + +register_dataset( + DatasetMeta( + ms_dataset_id='AI-ModelScope/firefly-train-1.1M', + hf_dataset_id='YeungNLP/firefly-train-1.1M', + preprocess_func=FireflyPreprocessor(), + tags=['chat', 'general'], + )) + +register_dataset( + DatasetMeta( + ms_dataset_id='modelscope/clue', + hf_dataset_id='clue', + subsets=['cmnli'], + preprocess_func=ClsGenerationPreprocessor(['neutral', 'entailment', 'contradiction'], + task='Natural Language Inference', + is_pair_seq=True), + tags=['text-generation', 'classification'], + split=['train', 'validation'], + )) + +register_dataset( + DatasetMeta( + ms_dataset_id='DAMO_NLP/jd', + subsets=[ + SubsetDataset( + 'default', + 'default', + preprocess_func=ClsGenerationPreprocessor(['negative', 'positive'], + task='Sentiment Classification', + is_pair_seq=False)), + SubsetDataset( + 'cls', + 'default', + preprocess_func=ClsPreprocessor(columns={'sentence': 'query'}), + ), + ], + tags=['text-generation', 'classification', '🔥'], + split=['train', 'validation'], + )) + + +class SyntheticText2SqlPreprocessor(ResponsePreprocessor): + + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]: + sql_prompt = row['sql_prompt'] + sql_context = row['sql_context'] + sql = row['sql'] + sql_explanation = row['sql_explanation'] + query = f'Sql Table information:\n{sql_context}\n{sql_prompt}' + response = f'Let\'s think step by step:\n{sql_explanation}\nSo the final sql is:\n{sql}' + return super().preprocess({'query': query, 'response': response}) + + +register_dataset( + DatasetMeta( + ms_dataset_id='AI-ModelScope/synthetic_text_to_sql', + hf_dataset_id='gretelai/synthetic_text_to_sql', + preprocess_func=SyntheticText2SqlPreprocessor(), + tags=['nl2sql', 'en'])) + + +def _repair_toolbench(conversations: List[Dict[str, str]]) -> List[Dict[str, str]]: + assert len(conversations) == 2 + if conversations[1]['from'] in {'caller', 'conclusion'}: + conversations[1]['from'] = 'assistant' + return conversations + + +register_dataset( + DatasetMeta( + ms_dataset_id='shenweizhou/alpha-umi-toolbench-processed-v2', + subsets=['backbone', 'caller', 'planner', 'summarizer'], + preprocess_func=MessagesPreprocessor(repair_messages=_repair_toolbench), + tags=['chat', 'agent', '🔥'], + huge_dataset=True)) + + +class BlossomMathPreprocessor(ResponsePreprocessor): + + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]: + output, answer = row['output'], row['answer'] + return super().preprocess({'query': row['query'], 'response': f'{output}\n\nAnswer: {answer}'}) + + +register_dataset( + DatasetMeta( + ms_dataset_id='AI-ModelScope/blossom-math-v2', + hf_dataset_id='Azure99/blossom-math-v2', + preprocess_func=BlossomMathPreprocessor(), + tags=['chat', 'math', '🔥'])) + +register_dataset( + DatasetMeta( + ms_dataset_id='AI-ModelScope/sql-create-context', + hf_dataset_id='b-mc2/sql-create-context', + preprocess_func=AlpacaPreprocessor(columns={ + 'question': 'instruction', + 'context': 'input', + 'answer': 'output' + }), + tags=['chat', 'sql', '🔥'])) + + +class TigerBotLawPreprocessor(ResponsePreprocessor): + + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]: + prompt = """{type} +{title} +""" + cur_prompt = prompt.format(type=row['type'], title=row['title']) + for i in range(1, 4): + chapter = row[f'chapter{i}'] + if chapter is not None: + cur_prompt += f'{chapter}' + cur_prompt += f'{row["response"]}' + return super().preprocess({'response': cur_prompt}) + + +register_dataset( + DatasetMeta( + ms_dataset_id='AI-ModelScope/tigerbot-law-plugin', + hf_dataset_id='TigerResearch/tigerbot-law-plugin', + preprocess_func=TigerBotLawPreprocessor(), + tags=['text-generation', 'law', 'pretrained'])) + +register_dataset( + DatasetMeta( + ms_dataset_id='codefuse-ai/CodeExercise-Python-27k', + preprocess_func=MessagesPreprocessor(columns={'chat_rounds': 'messages'}), + tags=['chat', 'coding', '🔥'])) + + +class LeetcodePythonPreprocessor(ResponsePreprocessor): + + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]: + code_with_problem = row['code_with_problem'] + idx = code_with_problem.find('```python') + problem = code_with_problem[:idx] + if problem.startswith('# '): + problem = problem[2:] + code = code_with_problem[idx:].strip() + explanation = row['explanation_only'] + return super().preprocess({'query': problem, 'response': f'{code}\n\n{explanation}'}) + + +register_dataset( + DatasetMeta( + ms_dataset_id='AI-ModelScope/leetcode-solutions-python', + preprocess_func=LeetcodePythonPreprocessor(), + tags=['chat', 'coding', '🔥'])) + + +class StsbPreprocessor(ResponsePreprocessor): + + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]: + row = { + 'query': row['sentence1'], + 'response': row['sentence2'], + 'label': row['score'], + } + return super().preprocess(row) + + +class StsbGeneratePreprocessor(ResponsePreprocessor): + prompt = """Task: Based on the given two sentences, provide a similarity score between 0.0 and 1.0. +Sentence 1: {text1} +Sentence 2: {text2} +Similarity score: """ + + def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: + return super().preprocess({ + 'query': self.prompt.format(text1=row['sentence1'], text2=row['sentence2']), + 'response': f"{row['score']:.1f}" + }) + + +class StsbRegressionPreprocessor(StsbGeneratePreprocessor): + + def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: + return super(StsbGeneratePreprocessor, self).preprocess({ + 'query': + self.prompt.format(text1=row['sentence1'], text2=row['sentence2']), + 'label': + row['score'] + }) + + +register_dataset( + DatasetMeta( + ms_dataset_id='sentence-transformers/stsb', + hf_dataset_id='sentence-transformers/stsb', + subsets=[ + SubsetDataset('default', preprocess_func=StsbPreprocessor()), # embedding + SubsetDataset('generate', preprocess_func=StsbGeneratePreprocessor()), + SubsetDataset('reg', preprocess_func=StsbRegressionPreprocessor()), + ], + tags=['similarity', '🔥'])) + + +def _repair_conversations_agent_instruct(s: str) -> List[Dict[str, Any]]: + s = s.replace('}\n {', '},\n {') + if isinstance(s, str): + s = ast.literal_eval(s) + return s + + +register_dataset( + DatasetMeta( + ms_dataset_id='huangjintao/AgentInstruct_copy', + subsets=['alfworld', 'db', 'kg', 'mind2web', 'os', 'webshop'], + preprocess_func=MessagesPreprocessor(repair_messages=_repair_conversations_agent_instruct), + tags=['chat', 'agent', 'multi-round'])) + + +class MultiRoleAgentPreprocessor(RowPreprocessor): + + def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: + conv = row['conversations'] + res_prompt = '\n\n【注意事项】\n1. 这是聊天室,不要发送私信给任何人\n2. 仅代表你个人说话,不要扮演其他人,只根据对话历史进行回复\n3. 长话短说,不要说太多话,不要超过50字 ' + history_prompt = '\n\n【chat history】' + conv_prompt = '\n {name}:{content}' + query, response = '', conv[-1]['value'] + system = conv[0]['value'] if conv[0]['from'] == 'system' else '' + if conv[0]['from'] == 'user': + query = conv[0]['value'] + elif 'next_speakers:' not in system: + if '【注意事项】' not in system and system: + system += res_prompt + system += history_prompt + system += ''.join([conv_prompt.format(name=c['from'], content=c['value']) for c in conv[1:-1]]) + + if not query or not response: + return + + return { + 'messages': [{ + 'role': 'system', + 'content': system + }, { + 'role': 'user', + 'content': query + }, { + 'role': 'assistant', + 'content': response + }], + } + + +register_dataset( + DatasetMeta( + ms_dataset_id='iic/MSAgent-MultiRole', + preprocess_func=MultiRoleAgentPreprocessor(), + tags=['chat', 'agent', 'multi-round', 'role-play', 'multi-agent'])) + +register_dataset(DatasetMeta(ms_dataset_id='swift/ToolBench', tags=['chat', 'agent', 'multi-round'])) + +register_dataset( + DatasetMeta( + ms_dataset_id='tastelikefeet/competition_math', + subsets=[ + SubsetDataset( + name='default', + subset='default', + split=['train', 'test'], + ), + ], + tags=['qa', 'math'])) + +register_dataset(DatasetMeta(ms_dataset_id='modelscope/gsm8k', subsets=['main'], split=['train'], tags=['qa', 'math'])) + +register_dataset( + DatasetMeta(ms_dataset_id='modelscope/MathR', subsets=['default', 'clean'], split=['train'], tags=['qa', 'math'])) + +register_dataset( + DatasetMeta(ms_dataset_id='modelscope/MathR-32B-Distill', subsets=['data'], split=['train'], tags=['qa', 'math'])) + + +class CoundownTaskPreprocessor(ResponsePreprocessor): + + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]: + numbers = row['nums'] + target = row.pop('response', None) + query = (f'Using the numbers {numbers}, create an equation that equals {target}.\n' + 'You can use basic arithmetic operations (+, -, *, /) and each number can only be used once.\n' + 'Show your work in tags. And return the final equation and answer ' + 'in tags, for example (1 + 2) / 3 * 4 = 4 .') + row.update({'target': target, 'query': query}) + return super().preprocess(row) + + +register_dataset( + DatasetMeta( + ms_dataset_id='zouxuhong/Countdown-Tasks-3to4', + subsets=['default'], + preprocess_func=CoundownTaskPreprocessor(), + tags=['math'])) + + +class HC3Preprocessor(ResponsePreprocessor): + prompt = """Classification Task: Are the following responses from a human or from ChatGPT? +Question: {question} +Answer: {answer} +Category: Human, ChatGPT +Output:""" + + def preprocess(self, row): + rows = [] + for response in ['Human', 'ChatGPT']: + query = self.prompt.format( + question=row['query'], answer=self.random_state.choice(row[f'{response.lower()}_answers'])) + rows.append(super().preprocess({'query': query, 'response': response})) + return rows + + +class HC3ClsPreprocessor(HC3Preprocessor): + + def preprocess(self, row): + rows = [] + for i, response in enumerate(['Human', 'ChatGPT']): + query = self.prompt.format( + question=row['query'], answer=self.random_state.choice(row[f'{response.lower()}_answers'])) + rows.append(ResponsePreprocessor.preprocess(self, {'query': query, 'label': i})) + return rows + + +hc3_subset_names = ['baike', 'open_qa', 'nlpcc_dbqa', 'finance', 'medicine', 'law', 'psychology'] +hc3_subsets: List[SubsetDataset] = [] +for hc3_subset_name in hc3_subset_names: + hc3_subsets.append( + SubsetDataset( + name=hc3_subset_name, + subset=hc3_subset_name, + preprocess_func=HC3Preprocessor(), + )) + hc3_subsets.append( + SubsetDataset( + name=f'{hc3_subset_name}_cls', + subset=hc3_subset_name, + preprocess_func=HC3ClsPreprocessor(), + )) + +register_dataset( + DatasetMeta( + ms_dataset_id='simpleai/HC3-Chinese', + hf_dataset_id='Hello-SimpleAI/HC3-Chinese', + subsets=hc3_subsets, + tags=['text-generation', 'classification', '🔥'])) + +hc3_subset_names = ['finance', 'medicine'] +hc3_subsets: List[SubsetDataset] = [] +for hc3_subset_name in hc3_subset_names: + hc3_subsets.append( + SubsetDataset( + name=hc3_subset_name, + subset=hc3_subset_name, + preprocess_func=HC3Preprocessor(), + )) + hc3_subsets.append( + SubsetDataset( + name=f'{hc3_subset_name}_cls', + subset=hc3_subset_name, + preprocess_func=HC3ClsPreprocessor(), + )) + +register_dataset( + DatasetMeta( + ms_dataset_id='simpleai/HC3', + hf_dataset_id='Hello-SimpleAI/HC3', + subsets=hc3_subsets, + preprocess_func=HC3Preprocessor(), + tags=['text-generation', 'classification', '🔥'])) + + +class DureaderPreprocessor(RowPreprocessor): + + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]: + prompt = """Task: Question Generation +Context: {context} +Answer: {answer} +Question:""" + answer, context = row['text1'].split('[SEP]') + return { + 'messages': [{ + 'role': 'user', + 'content': prompt.format(context=context, answer=answer) + }, { + 'role': 'assistant', + 'content': row['text2'] + }] + } + + +register_dataset( + DatasetMeta( + ms_dataset_id='modelscope/DuReader_robust-QG', + preprocess_func=DureaderPreprocessor(), + split=['train', 'validation', 'test'], + tags=['text-generation', '🔥'])) + + +class HHRLHFPreprocessor(RowPreprocessor): + + @staticmethod + def _to_messages(data): + messages = [] + for query, response in zip(data[::2], data[1::2]): + messages.append({'role': 'user', 'content': query}) + messages.append({'role': 'assistant', 'content': response}) + return messages + + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]: + chosen = row['chosen'].strip() + rejected = row['rejected'].strip() + parts_chosen = [s.strip() for s in re.split('\n\nHuman:|\n\nAssistant:|\n\nHum:', chosen)] + parts_rejected = [s.strip() for s in re.split('\n\nHuman:|\n\nAssistant:|\n\nHum:', rejected)] + if parts_chosen[0].startswith('Human:'): + assert parts_rejected[0].startswith('Human:') + parts_chosen[0] = parts_chosen[0][6:].strip() + parts_rejected[0] = parts_rejected[0][6:].strip() + row['messages'] = self._to_messages(parts_chosen) + row['rejected_messages'] = self._to_messages(parts_rejected) + return row + + +# TODO meta file broken +register_dataset( + DatasetMeta( + ms_dataset_id='AI-ModelScope/hh-rlhf', + subsets=['helpful-base', 'helpful-online', 'helpful-rejection-sampled'], + preprocess_func=HHRLHFPreprocessor(), + split=['train', 'test'], + tags=['rlhf', 'dpo'], + huge_dataset=True)) + + +class XlamFunctionCallingPreprocessor(ResponsePreprocessor): + + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]: + query = row['query'] + answers = row['response'] + if isinstance(answers, str): + answers = json.loads(answers) + answer = np.random.choice(answers) + name = answer['name'] + args = json.dumps(answer['arguments']) + response = f'Action: {name}\nAction Input: {args}' + row = {'query': query, 'response': response, 'solution': response, 'tools': row['tools']} + return super().preprocess(row) + + +register_dataset( + DatasetMeta( + ms_dataset_id='LLM-Research/xlam-function-calling-60k', + subsets=['dataset'], + preprocess_func=XlamFunctionCallingPreprocessor(), + tags=['agent'])) + + +class HHRLHFCNPreprocessor(MessagesPreprocessor): + + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]: + row['messages'].append(row.pop('chosen')) + row['rejected_response'] = row['rejected']['text'] + return super().preprocess(row) + + +register_dataset( + DatasetMeta( + ms_dataset_id='AI-ModelScope/hh_rlhf_cn', + subsets=['hh_rlhf', 'harmless_base_cn', 'harmless_base_en', 'helpful_base_cn', 'helpful_base_en'], + preprocess_func=HHRLHFCNPreprocessor(columns={'context': 'messages'}, content_key='text'), + split=['train', 'test'], + tags=['rlhf', 'dpo', '🔥'])) + + +def repair_conversations(s: Union[str, Any]) -> Any: + if isinstance(s, str): + s = s.replace('}\n {', '},{') + s = s.replace('}\n{', '},{') + s = s.replace('}{', '},{') + s = s.replace('}\n {', '},{') + return ast.literal_eval(s) + return s + + +register_dataset( + DatasetMeta( + ms_dataset_id='AI-ModelScope/lmsys-chat-1m', + hf_dataset_id='lmsys/lmsys-chat-1m', + preprocess_func=MessagesPreprocessor(repair_messages=repair_conversations), + tags=['chat', 'em'])) + +register_dataset( + DatasetMeta( + ms_dataset_id='hjh0119/shareAI-Llama3-DPO-zh-en-emoji', + hf_dataset_id='shareAI/DPO-zh-en-emoji', + preprocess_func=ResponsePreprocessor(columns={ + 'answer_zh': 'response', + 'answer_en': 'rejected_response' + }), + tags=['rlhf', 'dpo'])) + +register_dataset( + DatasetMeta(ms_dataset_id='AI-ModelScope/ultrafeedback-binarized-preferences-cleaned-kto', tags=['rlhf', 'kto'])) + +register_dataset( + DatasetMeta( + ms_dataset_id='OmniData/Zhihu-KOL-More-Than-100-Upvotes', + hf_dataset_id='bzb2023/Zhihu-KOL-More-Than-100-Upvotes', + tags=['zhihu', 'qa'])) + +register_dataset( + DatasetMeta( + ms_dataset_id='OmniData/Zhihu-KOL', + hf_dataset_id='wangrui6/Zhihu-KOL', + huge_dataset=True, + tags=['zhihu', 'qa'], + )) + + +class GuanacoPreprocessor(RowPreprocessor): + + def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: + instruction = row['instruction'] + input = row['input'] + output = row['output'] + history = [] + if instruction: + parts = split_str_parts_by( + instruction, ['User:', 'User:', 'Assistant:', 'Assistant:', 'Asssistent:', 'Assistent:', 'Assistenz:']) + for idx, part in enumerate(parts): + if idx % 2 == 0: + if 'user' not in part['key'].lower(): + return + history.append([part['content'], None]) + else: + if 'assist' not in part['key'].lower() and 'asssist' not in part['key'].lower(): + return + history[-1][-1] = part['content'] + if input.startswith('User:'): + input = input[len('User:'):].strip() + if any([not h[0] or not h[1] for h in history]): + return + + messages = [] + for h in history: + messages.append({'role': 'user', 'content': h[0]}) + messages.append({'role': 'assistant', 'content': h[1]}) + messages.append({'role': 'user', 'content': input}) + messages.append({'role': 'assistant', 'content': output}) + return { + 'messages': messages, + } + + +register_dataset( + DatasetMeta( + ms_dataset_id='AI-ModelScope/GuanacoDataset', + hf_dataset_id='JosephusCheung/GuanacoDataset', + preprocess_func=GuanacoPreprocessor(), + tags=['chat', 'zh'])) + + +class FunctionCallChatmlPreprocessor(MessagesPreprocessor): + + def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: + res = super().preprocess(row) + + if res['function_description']: + res['tools'] = res['function_description'].split('\n\n') + messages = res['messages'] + if messages[0]['role'] == 'system': + messages.pop(0) + return res + + +register_dataset( + DatasetMeta( + ms_dataset_id='AI-ModelScope/function-calling-chatml', + hf_dataset_id='Locutusque/function-calling-chatml', + preprocess_func=FunctionCallChatmlPreprocessor(), + tags=['agent', 'en', 'sft', '🔥'])) + + +class Dolly15kPreprocessor(RowPreprocessor): + + def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: + instruction = row['instruction'] + context = row['context'] + response = row['response'] + query = '' + if context: + query = 'Here gives some useful information:\n' + query += context + query += '\n' + query += instruction + return { + 'messages': [{ + 'role': 'user', + 'content': query + }, { + 'role': 'assistant', + 'content': response + }], + } + + +register_dataset( + DatasetMeta( + ms_dataset_id='AI-ModelScope/databricks-dolly-15k', + hf_dataset_id='databricks/databricks-dolly-15k', + preprocess_func=Dolly15kPreprocessor(), + tags=['multi-task', 'en', 'quality'])) + + +class OrpoDPOMix40kPreprocessor(MessagesPreprocessor): + + def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: + if row['source'] == 'toxic-dpo-v0.2': + return + return super().preprocess(row) + + +register_dataset( + DatasetMeta( + ms_dataset_id='AI-ModelScope/orpo-dpo-mix-40k', + hf_dataset_id='mlabonne/orpo-dpo-mix-40k', + preprocess_func=OrpoDPOMix40kPreprocessor(columns={ + 'chosen': 'messages', + 'rejected': 'rejected_messages' + }), + tags=['dpo', 'orpo', 'en', 'quality'])) + +register_dataset( + DatasetMeta( + ms_dataset_id='swift/sharegpt', + subsets=['common-zh', 'unknow-zh', 'common-en'], + tags=['chat', 'general', 'multi-round'])) + + +class SelfCognitionPreprocessor(ResponsePreprocessor): + name: Optional[Tuple[str, str]] = None + author: Optional[Tuple[str, str]] = None + + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]: + for key in ['name', 'author']: + val = getattr(self, key) + if val is None: + continue + val = val[0] if row['tag'] == 'zh' else val[1] + if val is None: + continue + placeholder = '{{' + key.upper() + '}}' + row['query'] = row['query'].replace(placeholder, val) + row['response'] = row['response'].replace(placeholder, val) + return super().preprocess(row) + + +class Qwen3SelfCognitionPreprocessor(SelfCognitionPreprocessor): + + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]: + row['query'] = row['query'] + ' /no_think' + row['response'] = '\n\n\n\n' + row['response'] + return super().preprocess(row) + + +class EmptyThinkSelfCognitionPreprocessor(SelfCognitionPreprocessor): + + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]: + row['response'] = '\n\n\n\n' + row['response'] + return super().preprocess(row) + + +register_dataset( + DatasetMeta( + ms_dataset_id='swift/self-cognition', + hf_dataset_id='modelscope/self-cognition', + subsets=[ + SubsetDataset(preprocess_func=SelfCognitionPreprocessor()), + SubsetDataset('qwen3', preprocess_func=Qwen3SelfCognitionPreprocessor()), + SubsetDataset('empty_think', preprocess_func=EmptyThinkSelfCognitionPreprocessor()), + ], + tags=['chat', 'self-cognition', '🔥'])) diff --git a/ms-swift/swift/llm/dataset/preprocessor/__pycache__/__init__.cpython-310.pyc b/ms-swift/swift/llm/dataset/preprocessor/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d015f01a86197b316216c9e75e96e4099beaedd Binary files /dev/null and b/ms-swift/swift/llm/dataset/preprocessor/__pycache__/__init__.cpython-310.pyc differ diff --git a/ms-swift/swift/llm/dataset/preprocessor/__pycache__/core.cpython-310.pyc b/ms-swift/swift/llm/dataset/preprocessor/__pycache__/core.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3984aaa7b612c424f9a4651b0818179795c3072 Binary files /dev/null and b/ms-swift/swift/llm/dataset/preprocessor/__pycache__/core.cpython-310.pyc differ diff --git a/ms-swift/swift/llm/dataset/register.py b/ms-swift/swift/llm/dataset/register.py new file mode 100644 index 0000000000000000000000000000000000000000..2fad80df0b7fc4a43fafc7c4d83230a7f72826d5 --- /dev/null +++ b/ms-swift/swift/llm/dataset/register.py @@ -0,0 +1,177 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from copy import deepcopy +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import json + +from swift.utils import get_logger, use_hf_hub +from .preprocessor import DATASET_TYPE, AutoPreprocessor, MessagesPreprocessor + +PreprocessFunc = Callable[..., DATASET_TYPE] +LoadFunction = Callable[..., DATASET_TYPE] +logger = get_logger() + + +@dataclass +class SubsetDataset: + # `Name` is used for matching subsets of the dataset, and `subset` refers to the subset_name on the hub. + name: Optional[str] = None + # If set to None, then subset is set to subset_name. + subset: str = 'default' + + # Higher priority. If set to None, the attributes of the DatasetMeta will be used. + split: Optional[List[str]] = None + preprocess_func: Optional[PreprocessFunc] = None + + # If the dataset specifies "all," weak subsets will be skipped. + is_weak_subset: bool = False + + def __post_init__(self): + if self.name is None: + self.name = self.subset + + def set_default(self, dataset_meta: 'DatasetMeta') -> 'SubsetDataset': + subset_dataset = deepcopy(self) + for k in ['split', 'preprocess_func']: + v = getattr(subset_dataset, k) + if v is None: + setattr(subset_dataset, k, deepcopy(getattr(dataset_meta, k))) + return subset_dataset + + +@dataclass +class DatasetMeta: + ms_dataset_id: Optional[str] = None + hf_dataset_id: Optional[str] = None + dataset_path: Optional[str] = None + dataset_name: Optional[str] = None + ms_revision: Optional[str] = None + hf_revision: Optional[str] = None + + subsets: List[Union[SubsetDataset, str]] = field(default_factory=lambda: ['default']) + # Applicable to all subsets. + split: List[str] = field(default_factory=lambda: ['train']) + # First perform column mapping, then proceed with the preprocess_func. + preprocess_func: PreprocessFunc = field(default_factory=lambda: AutoPreprocessor()) + load_function: Optional[LoadFunction] = None + + tags: List[str] = field(default_factory=list) + help: Optional[str] = None + huge_dataset: bool = False + + def __post_init__(self): + from .loader import DatasetLoader + if self.load_function is None: + self.load_function = DatasetLoader.load + for i, subset in enumerate(self.subsets): + if isinstance(subset, str): + self.subsets[i] = SubsetDataset(subset=subset) + + +DATASET_MAPPING: Dict[Tuple[str, str, str], DatasetMeta] = {} + + +def get_dataset_list(): + datasets = [] + for key in DATASET_MAPPING: + if use_hf_hub(): + if key[1]: + datasets.append(key[1]) + else: + if key[0]: + datasets.append(key[0]) + return datasets + + +def register_dataset(dataset_meta: DatasetMeta, *, exist_ok: bool = False) -> None: + """Register dataset + + Args: + dataset_meta: The `DatasetMeta` info of the dataset. + exist_ok: If the dataset id exists, raise error or update it. + """ + if dataset_meta.dataset_name: + dataset_name = dataset_meta.dataset_name + else: + dataset_name = dataset_meta.ms_dataset_id, dataset_meta.hf_dataset_id, dataset_meta.dataset_path + if not exist_ok and dataset_name in DATASET_MAPPING: + raise ValueError(f'The `{dataset_name}` has already been registered in the DATASET_MAPPING.') + + DATASET_MAPPING[dataset_name] = dataset_meta + + +def _preprocess_d_info(d_info: Dict[str, Any], *, base_dir: Optional[str] = None) -> Dict[str, Any]: + d_info = deepcopy(d_info) + + columns = None + if 'columns' in d_info: + columns = d_info.pop('columns') + + if 'messages' in d_info: + d_info['preprocess_func'] = MessagesPreprocessor(**d_info.pop('messages'), columns=columns) + else: + d_info['preprocess_func'] = AutoPreprocessor(columns=columns) + + if 'dataset_path' in d_info: + dataset_path = d_info.pop('dataset_path') + if base_dir is not None and not os.path.isabs(dataset_path): + dataset_path = os.path.join(base_dir, dataset_path) + dataset_path = os.path.abspath(os.path.expanduser(dataset_path)) + + d_info['dataset_path'] = dataset_path + + if 'subsets' in d_info: + subsets = d_info.pop('subsets') + for i, subset in enumerate(subsets): + if isinstance(subset, dict): + subsets[i] = SubsetDataset(**_preprocess_d_info(subset)) + d_info['subsets'] = subsets + return d_info + + +def _register_d_info(d_info: Dict[str, Any], *, base_dir: Optional[str] = None) -> DatasetMeta: + """Register a single dataset to dataset mapping + + Args: + d_info: The dataset info + """ + d_info = _preprocess_d_info(d_info, base_dir=base_dir) + dataset_meta = DatasetMeta(**d_info) + register_dataset(dataset_meta) + return dataset_meta + + +def register_dataset_info(dataset_info: Union[str, List[str], None] = None) -> List[DatasetMeta]: + """Register dataset from the `dataset_info.json` or a custom dataset info file + This is used to deal with the datasets defined in the json info file. + + Args: + dataset_info: The dataset info path + """ + # dataset_info_path: path, json or None + if dataset_info is None: + dataset_info = os.path.join(os.path.dirname(__file__), 'data', 'dataset_info.json') + assert isinstance(dataset_info, (str, list)) + base_dir = None + log_msg = None + if isinstance(dataset_info, str): + dataset_path = os.path.abspath(os.path.expanduser(dataset_info)) + if os.path.isfile(dataset_path): + log_msg = dataset_path + base_dir = os.path.dirname(dataset_path) + with open(dataset_path, 'r', encoding='utf-8') as f: + dataset_info = json.load(f) + else: + dataset_info = json.loads(dataset_info) # json + if len(dataset_info) == 0: + return [] + res = [] + for d_info in dataset_info: + res.append(_register_d_info(d_info, base_dir=base_dir)) + + if log_msg is None: + log_msg = dataset_info if len(dataset_info) < 5 else list(dataset_info.keys()) + logger.info(f'Successfully registered `{log_msg}`.') + return res diff --git a/ms-swift/swift/llm/ds_config/zero0.json b/ms-swift/swift/llm/ds_config/zero0.json new file mode 100644 index 0000000000000000000000000000000000000000..d22498c67f37cea4dc3a2a25fe4bd6c63802f657 --- /dev/null +++ b/ms-swift/swift/llm/ds_config/zero0.json @@ -0,0 +1,31 @@ +{ + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1 + }, + + "bf16": { + "enabled": "auto" + }, + + "zero_optimization": { + "stage": 0, + "allgather_partitions": true, + "allgather_bucket_size": 2e8, + "overlap_comm": false, + "reduce_scatter": true, + "reduce_bucket_size": 2e8, + "contiguous_gradients": true + }, + + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "steps_per_print": 2000, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false +} diff --git a/ms-swift/swift/llm/ds_config/zero1.json b/ms-swift/swift/llm/ds_config/zero1.json new file mode 100644 index 0000000000000000000000000000000000000000..55653e7055c2c33692498a1fc4aa8268857ef89e --- /dev/null +++ b/ms-swift/swift/llm/ds_config/zero1.json @@ -0,0 +1,35 @@ +{ + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1 + }, + + "bf16": { + "enabled": "auto" + }, + + "zero_optimization": { + "stage": 1, + "offload_optimizer": { + "device": "none", + "pin_memory": true + }, + "allgather_partitions": true, + "allgather_bucket_size": 2e8, + "overlap_comm": false, + "reduce_scatter": true, + "reduce_bucket_size": 2e8, + "contiguous_gradients": true + }, + + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "steps_per_print": 2000, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false +} diff --git a/ms-swift/swift/llm/eval/eval.py b/ms-swift/swift/llm/eval/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..a7fa38c3595e5bc032433e242093bcd33cdf924a --- /dev/null +++ b/ms-swift/swift/llm/eval/eval.py @@ -0,0 +1,156 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from contextlib import nullcontext +from typing import List, Union + +from evalscope.constants import EvalBackend, EvalType +from evalscope.run import TaskConfig, run_task +from evalscope.summarizer import Summarizer + +from swift.utils import append_to_jsonl, get_logger +from .. import MediaResource +from ..argument import EvalArguments +from ..base import SwiftPipeline +from ..infer import run_deploy + +logger = get_logger() + + +class SwiftEval(SwiftPipeline): + args_class = EvalArguments + args: args_class + + def run(self): + args = self.args + eval_report = {} + deploy_context = nullcontext() if args.eval_url else run_deploy(args, return_url=True) + with deploy_context as base_url: + base_url = args.eval_url or base_url + url = f"{base_url.rstrip('/')}/chat/completions" + + task_cfg = self.get_task_cfg(args.eval_dataset, args.eval_backend, url) + result = self.get_task_result(task_cfg) + eval_report[args.eval_backend] = result + + eval_report.update({ + 'time': args.time, + 'model': args.model, + 'adapters': args.adapters, + 'result_path': args.result_path, + 'eval_output_dir': args.eval_output_dir, + 'eval_limit': args.eval_limit + }) + + if args.result_jsonl: + append_to_jsonl(args.result_jsonl, eval_report) + logger.info(f'The eval result have been saved to result_jsonl: `{args.result_jsonl}`.') + return eval_report + + def get_task_result(self, task_cfg: TaskConfig): + run_task(task_cfg=task_cfg) + reports = Summarizer.get_report_from_cfg(task_cfg=task_cfg) + result = {} + if task_cfg.eval_backend == EvalBackend.OPEN_COMPASS: + for report in reports: + if report[self.args.model_suffix] != '-': + result[report['dataset']] = {report['metric']: report[self.args.model_suffix]} + elif task_cfg.eval_backend == EvalBackend.VLM_EVAL_KIT: + for report in reports: + splited_key = next(iter(report)).rsplit('_', 2) + if len(splited_key) == 3: + _, dataset, metric = splited_key + else: + dataset, metric = '-', '-' + result[dataset] = {metric: list(report.values())[0]} + else: + result = reports + return result + + def get_task_cfg(self, dataset: List[str], eval_backend: str, url: str): + assert eval_backend in {EvalBackend.NATIVE, EvalBackend.OPEN_COMPASS, EvalBackend.VLM_EVAL_KIT} + if eval_backend == EvalBackend.OPEN_COMPASS: + if self.args.local_dataset: + if os.path.exists('data'): + if not os.path.exists(os.path.join('data', 'CMB')): + raise RuntimeError('Opencompass need a `data` folder in your work dir(' + 'which will be created automatically by swift eval), ' + 'but a local path named `data` already exists, ' + 'please consider moving the dir to another location.') + else: + local_dir = MediaResource.download( + 'https://modelscope.cn/datasets/' + 'opencompass/OpenCompassDataComplete/' + 'resolve/master/OpenCompassData-complete-20240207.zip', 'OpenCompassData') + os.symlink(os.path.join(local_dir, 'data'), 'data') + + task_cfg = self.get_opencompass_task_cfg(dataset, url) + elif eval_backend == EvalBackend.VLM_EVAL_KIT: + task_cfg = self.get_vlmeval_task_cfg(dataset, url) + else: + task_cfg = self.get_native_task_cfg(dataset, url) + return task_cfg + + def get_native_task_cfg(self, dataset: List[str], url: str): + args = self.args + work_dir = os.path.join(args.eval_output_dir, 'native') + return TaskConfig( + model=args.model_suffix, + eval_type=EvalType.SERVICE, + api_url=url, + api_key=args.api_key or 'EMPTY', + datasets=dataset, + work_dir=work_dir, + limit=args.eval_limit, + eval_batch_size=args.eval_num_proc, + dataset_args=args.dataset_args, + generation_config=args.eval_generation_config, + **args.extra_eval_args) + + def get_opencompass_task_cfg(self, dataset: List[str], url: str): + args = self.args + work_dir = os.path.join(args.eval_output_dir, 'opencompass') + return TaskConfig( + eval_backend=EvalBackend.OPEN_COMPASS, + eval_config={ + 'datasets': + dataset, + 'batch_size': + args.eval_num_proc, + 'work_dir': + work_dir, + 'models': [{ + 'path': args.model_suffix, + 'openai_api_base': url, + 'key': args.api_key or 'EMPTY', + 'is_chat': args.use_chat_template + }], + 'limit': + args.eval_limit + }, + work_dir=work_dir) + + def get_vlmeval_task_cfg(self, dataset: List[str], url: str): + args = self.args + work_dir = os.path.join(args.eval_output_dir, 'vlmeval') + return TaskConfig( + eval_backend=EvalBackend.VLM_EVAL_KIT, + eval_config={ + 'data': + dataset, + 'model': [{ + 'type': args.model_suffix, + 'name': 'CustomAPIModel', + 'api_base': url, + 'key': args.api_key or 'EMPTY', + **args.eval_generation_config + }], + 'nproc': + args.eval_num_proc, + 'limit': + args.eval_limit + }, + work_dir=work_dir) + + +def eval_main(args: Union[List[str], EvalArguments, None] = None): + return SwiftEval(args).main() diff --git a/ms-swift/swift/llm/export/export.py b/ms-swift/swift/llm/export/export.py new file mode 100644 index 0000000000000000000000000000000000000000..d78658ba6287211b6e12dcac43081bb26dc15819 --- /dev/null +++ b/ms-swift/swift/llm/export/export.py @@ -0,0 +1,50 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import List, Union + +from swift.llm import ExportArguments, SwiftPipeline +from swift.tuners import swift_to_peft_format +from swift.utils import get_logger +from .merge_lora import merge_lora +from .ollama import export_to_ollama +from .quant import quantize_model + +logger = get_logger() + + +class SwiftExport(SwiftPipeline): + args_class = ExportArguments + args: args_class + + def run(self): + args = self.args + if args.to_peft_format: + args.adapters[0] = swift_to_peft_format(args.adapters[0], args.output_dir) + if args.merge_lora: + output_dir = args.output_dir + if args.to_peft_format or args.quant_method or args.to_ollama or args.push_to_hub: + args.output_dir = None + merge_lora(args) + args.output_dir = output_dir # recover + if args.quant_method: + quantize_model(args) + elif args.to_ollama: + export_to_ollama(args) + elif args.to_mcore: + from swift.megatron import convert_hf2mcore + convert_hf2mcore(args) + elif args.to_hf: + from swift.megatron import convert_mcore2hf + convert_mcore2hf(args) + elif args.push_to_hub: + model_dir = args.adapters and args.adapters[0] or args.model_dir + assert model_dir, f'model_dir: {model_dir}' + args.hub.push_to_hub( + args.hub_model_id, + model_dir, + token=args.hub_token, + private=args.hub_private_repo, + commit_message=args.commit_message) + + +def export_main(args: Union[List[str], ExportArguments, None] = None): + return SwiftExport(args).main() diff --git a/ms-swift/swift/llm/export/ollama.py b/ms-swift/swift/llm/export/ollama.py new file mode 100644 index 0000000000000000000000000000000000000000..c706de25b1a15fc0bd74d853c44453d136c05cbe --- /dev/null +++ b/ms-swift/swift/llm/export/ollama.py @@ -0,0 +1,69 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from typing import List + +from swift.llm import ExportArguments, PtEngine, RequestConfig, Template, prepare_model_template +from swift.utils import get_logger + +logger = get_logger() + + +def replace_and_concat(template: 'Template', template_list: List, placeholder: str, keyword: str): + final_str = '' + for t in template_list: + if isinstance(t, str): + final_str += t.replace(placeholder, keyword) + elif isinstance(t, (tuple, list)): + if isinstance(t[0], int): + final_str += template.tokenizer.decode(t) + else: + for attr in t: + if attr == 'bos_token_id': + final_str += template.tokenizer.bos_token + elif attr == 'eos_token_id': + final_str += template.tokenizer.eos_token + else: + raise ValueError(f'Unknown token: {attr}') + return final_str + + +def export_to_ollama(args: ExportArguments): + args.device_map = 'meta' # Accelerate load speed. + logger.info('Exporting to ollama:') + os.makedirs(args.output_dir, exist_ok=True) + model, template = prepare_model_template(args) + pt_engine = PtEngine.from_model_template(model, template) + logger.info(f'Using model_dir: {pt_engine.model_dir}') + template_meta = template.template_meta + with open(os.path.join(args.output_dir, 'Modelfile'), 'w', encoding='utf-8') as f: + f.write(f'FROM {pt_engine.model_dir}\n') + f.write(f'TEMPLATE """{{{{ if .System }}}}' + f'{replace_and_concat(template, template_meta.system_prefix, "{{SYSTEM}}", "{{ .System }}")}' + f'{{{{ else }}}}{replace_and_concat(template, template_meta.prefix, "", "")}' + f'{{{{ end }}}}') + f.write(f'{{{{ if .Prompt }}}}' + f'{replace_and_concat(template, template_meta.prompt, "{{QUERY}}", "{{ .Prompt }}")}' + f'{{{{ end }}}}') + f.write('{{ .Response }}') + f.write(replace_and_concat(template, template_meta.suffix, '', '') + '"""\n') + f.write(f'PARAMETER stop "{replace_and_concat(template, template_meta.suffix, "", "")}"\n') + + request_config = RequestConfig( + temperature=args.temperature, + top_k=args.top_k, + top_p=args.top_p, + repetition_penalty=args.repetition_penalty) + generation_config = pt_engine._prepare_generation_config(request_config) + pt_engine._add_stop_words(generation_config, request_config, template.template_meta) + for stop_word in generation_config.stop_words: + f.write(f'PARAMETER stop "{stop_word}"\n') + f.write(f'PARAMETER temperature {generation_config.temperature}\n') + f.write(f'PARAMETER top_k {generation_config.top_k}\n') + f.write(f'PARAMETER top_p {generation_config.top_p}\n') + f.write(f'PARAMETER repeat_penalty {generation_config.repetition_penalty}\n') + + logger.info('Save Modelfile done, you can start ollama by:') + logger.info('> ollama serve') + logger.info('In another terminal:') + logger.info('> ollama create my-custom-model ' f'-f {os.path.join(args.output_dir, "Modelfile")}') + logger.info('> ollama run my-custom-model') diff --git a/ms-swift/swift/llm/infer/deploy.py b/ms-swift/swift/llm/infer/deploy.py new file mode 100644 index 0000000000000000000000000000000000000000..757801be6666390a994774feb2bc61031f5958c2 --- /dev/null +++ b/ms-swift/swift/llm/infer/deploy.py @@ -0,0 +1,240 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import asyncio +import inspect +import multiprocessing +import time +from contextlib import contextmanager +from dataclasses import asdict +from http import HTTPStatus +from threading import Thread +from typing import List, Optional, Union + +import json +import uvicorn +from aiohttp import ClientConnectorError +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse, StreamingResponse + +from swift.llm import AdapterRequest, DeployArguments +from swift.llm.infer.protocol import MultiModalRequestMixin +from swift.plugin import InferStats +from swift.utils import JsonlWriter, get_logger +from .infer import SwiftInfer +from .infer_engine import InferClient +from .protocol import ChatCompletionRequest, CompletionRequest, Model, ModelList + +logger = get_logger() + + +class SwiftDeploy(SwiftInfer): + args_class = DeployArguments + args: args_class + + def _register_app(self): + self.app.get('/v1/models')(self.get_available_models) + self.app.post('/v1/chat/completions')(self.create_chat_completion) + self.app.post('/v1/completions')(self.create_completion) + + def __init__(self, args: Union[List[str], DeployArguments, None] = None) -> None: + super().__init__(args) + + self.infer_engine.strict = True + self.infer_stats = InferStats() + self.app = FastAPI(lifespan=self.lifespan) + self._register_app() + + async def _log_stats_hook(self): + while True: + await asyncio.sleep(self.args.log_interval) + self._compute_infer_stats() + self.infer_stats.reset() + + def _compute_infer_stats(self): + global_stats = self.infer_stats.compute() + for k, v in global_stats.items(): + global_stats[k] = round(v, 8) + logger.info(global_stats) + + def lifespan(self, app: FastAPI): + args = self.args + if args.log_interval > 0: + thread = Thread(target=lambda: asyncio.run(self._log_stats_hook()), daemon=True) + thread.start() + try: + yield + finally: + if args.log_interval > 0: + self._compute_infer_stats() + + def _get_model_list(self): + args = self.args + model_list = [args.served_model_name or args.model_suffix] + if args.adapter_mapping: + model_list += [name for name in args.adapter_mapping.keys()] + return model_list + + async def get_available_models(self): + model_list = self._get_model_list() + data = [Model(id=model_id, owned_by=self.args.owned_by) for model_id in model_list] + return ModelList(data=data) + + async def _check_model(self, request: ChatCompletionRequest) -> Optional[str]: + available_models = await self.get_available_models() + model_list = [model.id for model in available_models.data] + if request.model not in model_list: + return f'`{request.model}` is not in the model_list: `{model_list}`.' + + def _check_api_key(self, raw_request: Request) -> Optional[str]: + api_key = self.args.api_key + if api_key is None: + return + authorization = dict(raw_request.headers).get('authorization') + error_msg = 'API key error' + if authorization is None or not authorization.startswith('Bearer '): + return error_msg + request_api_key = authorization[7:] + if request_api_key != api_key: + return error_msg + + def _check_max_logprobs(self, request): + args = self.args + if isinstance(request.top_logprobs, int) and request.top_logprobs > args.max_logprobs: + return (f'The value of top_logprobs({request.top_logprobs}) is greater than ' + f'the server\'s max_logprobs({args.max_logprobs}).') + + @staticmethod + def create_error_response(status_code: Union[int, str, HTTPStatus], message: str) -> JSONResponse: + status_code = int(status_code) + return JSONResponse({'message': message, 'object': 'error'}, status_code) + + def _post_process(self, request_info, response, return_cmpl_response: bool = False): + args = self.args + + for i in range(len(response.choices)): + if not hasattr(response.choices[i], 'message') or not isinstance(response.choices[i].message.content, + (tuple, list)): + continue + for j, content in enumerate(response.choices[i].message.content): + if content['type'] == 'image': + b64_image = MultiModalRequestMixin.to_base64(content['image']) + response.choices[i].message.content[j]['image'] = f'data:image/jpg;base64,{b64_image}' + + is_finished = all(response.choices[i].finish_reason for i in range(len(response.choices))) + if 'stream' in response.__class__.__name__.lower(): + request_info['response'] += response.choices[0].delta.content + else: + request_info['response'] = response.choices[0].message.content + if return_cmpl_response: + response = response.to_cmpl_response() + if is_finished: + if args.log_interval > 0: + self.infer_stats.update(response) + if self.jsonl_writer: + self.jsonl_writer.append(request_info) + if self.args.verbose: + logger.info(request_info) + return response + + def _set_request_config(self, request_config) -> None: + default_request_config = self.args.get_request_config() + if default_request_config is None: + return + for key, val in asdict(request_config).items(): + default_val = getattr(default_request_config, key) + if default_val is not None and (val is None or isinstance(val, (list, tuple)) and len(val) == 0): + setattr(request_config, key, default_val) + + async def create_chat_completion(self, + request: ChatCompletionRequest, + raw_request: Request, + *, + return_cmpl_response: bool = False): + args = self.args + error_msg = (await self._check_model(request) or self._check_api_key(raw_request) + or self._check_max_logprobs(request)) + if error_msg: + return self.create_error_response(HTTPStatus.BAD_REQUEST, error_msg) + infer_kwargs = self.infer_kwargs.copy() + adapter_path = args.adapter_mapping.get(request.model) + if adapter_path: + infer_kwargs['adapter_request'] = AdapterRequest(request.model, adapter_path) + + infer_request, request_config = request.parse() + self._set_request_config(request_config) + request_info = {'response': '', 'infer_request': infer_request.to_printable()} + + def pre_infer_hook(kwargs): + request_info['generation_config'] = kwargs['generation_config'] + return kwargs + + infer_kwargs['pre_infer_hook'] = pre_infer_hook + try: + res_or_gen = await self.infer_async(infer_request, request_config, template=self.template, **infer_kwargs) + except Exception as e: + import traceback + logger.info(traceback.format_exc()) + return self.create_error_response(HTTPStatus.BAD_REQUEST, str(e)) + if request_config.stream: + + async def _gen_wrapper(): + async for res in res_or_gen: + res = self._post_process(request_info, res, return_cmpl_response) + yield f'data: {json.dumps(asdict(res), ensure_ascii=False)}\n\n' + yield 'data: [DONE]\n\n' + + return StreamingResponse(_gen_wrapper(), media_type='text/event-stream') + else: + return self._post_process(request_info, res_or_gen, return_cmpl_response) + + async def create_completion(self, request: CompletionRequest, raw_request: Request): + chat_request = ChatCompletionRequest.from_cmpl_request(request) + return await self.create_chat_completion(chat_request, raw_request, return_cmpl_response=True) + + def run(self): + args = self.args + self.jsonl_writer = JsonlWriter(args.result_path) if args.result_path else None + logger.info(f'model_list: {self._get_model_list()}') + uvicorn.run( + self.app, + host=args.host, + port=args.port, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + log_level=args.log_level) + + +def deploy_main(args: Union[List[str], DeployArguments, None] = None) -> None: + SwiftDeploy(args).main() + + +def is_accessible(port: int): + infer_client = InferClient(port=port) + try: + infer_client.get_model_list() + except ClientConnectorError: + return False + return True + + +@contextmanager +def run_deploy(args: DeployArguments, return_url: bool = False): + if isinstance(args, DeployArguments) and args.__class__.__name__ == 'DeployArguments': + deploy_args = args + else: + args_dict = asdict(args) + parameters = inspect.signature(DeployArguments).parameters + for k in list(args_dict.keys()): + if k not in parameters or args_dict[k] is None: + args_dict.pop(k) + deploy_args = DeployArguments(**args_dict) + + mp = multiprocessing.get_context('spawn') + process = mp.Process(target=deploy_main, args=(deploy_args, )) + process.start() + try: + while not is_accessible(deploy_args.port): + time.sleep(1) + yield f'http://127.0.0.1:{deploy_args.port}/v1' if return_url else deploy_args.port + finally: + process.terminate() + logger.info('The deployment process has been terminated.') diff --git a/ms-swift/swift/llm/infer/infer.py b/ms-swift/swift/llm/infer/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..622cd039e0401f8abcdbd3c2f87fb6ad9baf77d9 --- /dev/null +++ b/ms-swift/swift/llm/infer/infer.py @@ -0,0 +1,237 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict, List, Union + +import numpy as np +from datasets import Dataset as HfDataset + +from swift.llm import InferArguments, InferRequest, SwiftPipeline, load_dataset, prepare_model_template, sample_dataset +from swift.plugin import InferStats, MeanMetric, compute_rouge_bleu +from swift.utils import JsonlWriter, get_logger, is_master, read_from_jsonl +from .infer_engine import AdapterRequest, PtEngine +from .protocol import RequestConfig +from .utils import InferCliState + +logger = get_logger() + + +class SwiftInfer(SwiftPipeline): + args_class = InferArguments + args: args_class + + def __init__(self, args: Union[List[str], InferArguments, None] = None) -> None: + from swift.llm import merge_lora + super().__init__(args) + args = self.args + if args.merge_lora: + merge_lora(args, device_map='cpu') + self.infer_kwargs = {} + if args.infer_backend == 'vllm' and args.adapters: + self.infer_kwargs['adapter_request'] = AdapterRequest('_lora', args.adapters[0]) + + if args.infer_backend == 'pt': + model, self.template = prepare_model_template(args) + self.infer_engine = PtEngine.from_model_template(model, self.template, max_batch_size=args.max_batch_size) + logger.info(f'model: {self.infer_engine.model}') + else: + self.infer_engine = self.get_infer_engine(args) + self.template = args.get_template(self.processor) + self.random_state = np.random.RandomState(args.data_seed) + + def __getattr__(self, key: str): + try: + return super().__getattr__(key) + except AttributeError: + if 'infer_engine' in self.__dict__: + return getattr(self.infer_engine, key) + raise + + @staticmethod + def get_infer_engine(args: InferArguments, **kwargs): + kwargs.update({ + 'model_id_or_path': args.model, + 'model_type': args.model_type, + 'revision': args.model_revision, + 'torch_dtype': args.torch_dtype, + }) + infer_backend = kwargs.pop('infer_backend', None) or args.infer_backend + if infer_backend == 'pt': + from .infer_engine import PtEngine + infer_engine_cls = PtEngine + kwargs.update(args.get_model_kwargs()) + if hasattr(args, 'max_batch_size'): + kwargs.update({'max_batch_size': args.max_batch_size}) + elif infer_backend == 'vllm': + from .infer_engine import VllmEngine + infer_engine_cls = VllmEngine + kwargs.update(args.get_vllm_engine_kwargs()) + else: + from .infer_engine import LmdeployEngine + infer_engine_cls = LmdeployEngine + kwargs.update(args.get_lmdeploy_engine_kwargs()) + return infer_engine_cls(**kwargs) + + def run(self) -> List[Dict[str, Any]]: + args = self.args + self.jsonl_writer = JsonlWriter(args.result_path) if args.result_path else None + if args.eval_human: + result = self.infer_cli() + else: + result = self.infer_dataset() + if args.result_path: + logger.info(f'The inference results have been saved to result_path: `{args.result_path}`.') + return result + + def infer_single(self, infer_request: Union[InferRequest, Dict[str, Any]], request_config: RequestConfig) -> str: + res_or_gen = self.infer([infer_request], + request_config, + template=self.template, + use_tqdm=False, + **self.infer_kwargs)[0] + if request_config and request_config.stream: + response = '' + for res in res_or_gen: + delta = res.choices[0].delta.content + print(delta, end='', flush=True) + response += delta + print() + else: + response = res_or_gen.choices[0].message.content + print(response) + print('-' * 50) + return response + + def infer_cli(self) -> List[Dict[str, Any]]: + args = self.args + template = self.template + request_config = args.get_request_config() + logger.info(f'request_config: {request_config}') + + logger.info('Input `exit` or `quit` to exit the conversation.') + logger.info('Input `multi-line` to switch to multi-line input mode.') + logger.info('Input `reset-system` to reset the system and clear the history.') + support_multi_round = template.template_meta.support_multi_round + if support_multi_round: + logger.info('Input `clear` to clear the history.') + else: + logger.info('The current template only supports single-round dialogues.') + + infer_state = InferCliState() + result_list = [] + while True: + if not support_multi_round: + infer_state.clear() + query = infer_state.input_text() + if query.strip().lower() in {'exit', 'quit'}: + break + query = infer_state.check_query(query) + if query is None: + continue + infer_state.add_query(query) + if args.model_meta.is_multimodal: + infer_state.input_mm_data() + if args.model_meta.is_reward or args.task_type == 'prm': + # reward model + response = infer_state.input_text() + infer_state.add_response(response) + data = infer_state.to_dict() + response = self.infer_single(data, request_config) + data = {'response': response, **data} + else: + data = infer_state.to_dict() + response = self.infer_single(data, request_config) + infer_state.add_response(response) + data['messages'].append({'role': 'assistant', 'content': response}) + data = {'response': response, **data} + result_list.append(data) + if self.jsonl_writer: + self.jsonl_writer.append(data) + + return result_list + + def _prepare_val_dataset(self) -> HfDataset: + args = self.args + dataset_kwargs = args.get_dataset_kwargs() + if len(args.val_dataset) > 0: + _, val_dataset = load_dataset( + args.val_dataset, split_dataset_ratio=1.0, shuffle=args.dataset_shuffle, **dataset_kwargs) + else: + _, val_dataset = load_dataset( + args.dataset, + split_dataset_ratio=args.split_dataset_ratio, + shuffle=args.val_dataset_shuffle, + **dataset_kwargs) + assert val_dataset is not None + val_dataset = sample_dataset(val_dataset, args.val_dataset_sample, args.dataset_shuffle, self.random_state) + return val_dataset + + def _calc_metric(self): + args = self.args + if not is_master(): + return + data_list = read_from_jsonl(self.jsonl_writer.fpath) + preds, labels = [], [] + for data in data_list: + preds.append(data['response']) + labels.append(data['labels']) + if args.metric == 'acc': + mean_metric = MeanMetric() + for pred, label in zip(preds, labels): + mean_metric.update(pred == label) + res = {'acc': mean_metric.compute()['value']} + elif args.metric == 'rouge': + res = compute_rouge_bleu(preds, labels) + logger.info(res) + + def infer_dataset(self) -> List[Dict[str, Any]]: + args = self.args + request_config = args.get_request_config() + logger.info(f'request_config: {request_config}') + + val_dataset = self._prepare_val_dataset() + logger.info(f'val_dataset: {val_dataset}') + result_list = [] + self.infer_kwargs['metrics'] = [InferStats()] + if request_config and request_config.stream: + for data in val_dataset: + labels = InferRequest.remove_response(data['messages']) + query = data['messages'][-1]['content'] + print(f'[QUERY] {query}') + if labels: + print(f'[LABELS] {labels}') + print('[RESPONSE] ', end='') + response = self.infer_single(data, request_config) + data['messages'].append({'role': 'assistant', 'content': response}) + data = {'response': response, 'labels': labels, **data} + result_list.append(data) + if self.jsonl_writer: + self.jsonl_writer.append(data) + else: + if args.rank >= 0 and args.global_world_size > 1: + val_dataset = val_dataset.shard(args.global_world_size, args.rank, contiguous=True) + val_dataset = list(val_dataset) + labels_list = [] + for data in val_dataset: + if args.task_type == 'causal_lm': + labels = InferRequest.remove_response(data['messages']) + else: + labels = data.pop('label', None) + labels_list.append(labels) + + resp_list = self.infer( + val_dataset, request_config, template=self.template, use_tqdm=True, **self.infer_kwargs) + for data, resp, labels in zip(val_dataset, resp_list, labels_list): + response = resp.choices[0].message.content + data['messages'].append({'role': 'assistant', 'content': response}) + data = {'response': response, 'labels': labels, 'logprobs': resp.choices[0].logprobs, **data} + result_list.append(data) + if self.jsonl_writer: + self.jsonl_writer.append(result_list, gather_obj=True) + metrics = self.infer_kwargs.pop('metrics') + print(f'[rank{args.rank}] {metrics[0].compute()}') + if args.metric is not None: + self._calc_metric() + return result_list + + +def infer_main(args: Union[List[str], InferArguments, None] = None): + return SwiftInfer(args).main() diff --git a/ms-swift/swift/llm/infer/infer_engine/__pycache__/__init__.cpython-310.pyc b/ms-swift/swift/llm/infer/infer_engine/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04aa5817ab7a0a127a7cb638bafab6879907ab9a Binary files /dev/null and b/ms-swift/swift/llm/infer/infer_engine/__pycache__/__init__.cpython-310.pyc differ diff --git a/ms-swift/swift/llm/infer/infer_engine/base.py b/ms-swift/swift/llm/infer/infer_engine/base.py new file mode 100644 index 0000000000000000000000000000000000000000..866f1dc8c0d4a892fc261d1351a92b266ff873fe --- /dev/null +++ b/ms-swift/swift/llm/infer/infer_engine/base.py @@ -0,0 +1,59 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from abc import ABC, abstractmethod +from typing import AsyncIterator, Iterator, List, Optional, Union + +from swift.llm import InferRequest +from swift.plugin import Metric +from ..protocol import ChatCompletionResponse, ChatCompletionStreamResponse, RequestConfig + + +class BaseInferEngine(ABC): + + @abstractmethod + def infer(self, + infer_requests: List[InferRequest], + request_config: Optional[RequestConfig] = None, + metrics: Optional[List[Metric]] = None, + *, + use_tqdm: Optional[bool] = None, + **kwargs) -> List[Union[ChatCompletionResponse, Iterator[ChatCompletionStreamResponse]]]: + """ + This method performs inference on a list of inference requests. + + The method takes a list of inference requests and processes them according to the provided configuration. + It can optionally use tqdm for progress visualization and accept additional keyword arguments. + + Args: + infer_requests (List[InferRequest]): A list of inference requests to be processed. + request_config (Optional[RequestConfig]): Configuration for the request, if any. + metrics (Optional[List[Metric]]): A list of usage information to return. + use_tqdm (Optional[bool]): Whether to use tqdm for progress visualization. + **kwargs: Additional keyword arguments. + + Returns: + List[Union[ChatCompletionResponse, Iterator[ChatCompletionStreamResponse]]]: + The result of the inference. + """ + pass + + @abstractmethod + async def infer_async(self, + infer_request: InferRequest, + request_config: Optional[RequestConfig] = None, + **kwargs) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionStreamResponse]]: + """ + This method performs asynchronous inference on a single inference request. + + The method takes an inference request and processes it according to the provided configuration. + It can accept additional keyword arguments. + + Args: + infer_request (InferRequest): An inference request to be processed. + request_config (Optional[RequestConfig]): Configuration for the request, if any. + **kwargs: Additional keyword arguments. + + Returns: + Union[ChatCompletionResponse, AsyncIterator[ChatCompletionStreamResponse]]: The result of + the asynchronous inference. + """ + pass diff --git a/ms-swift/swift/llm/infer/infer_engine/grpo_vllm_engine.py b/ms-swift/swift/llm/infer/infer_engine/grpo_vllm_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..0e5b4e736381abaca27da973fe5a0279cedb03cd --- /dev/null +++ b/ms-swift/swift/llm/infer/infer_engine/grpo_vllm_engine.py @@ -0,0 +1,152 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from copy import copy, deepcopy +from typing import Any, Dict, Iterator, List, Optional, Union + +import torch +from packaging import version + +from swift.llm import InferRequest, Template, VllmEngine, get_model_tokenizer +from swift.plugin import Metric +from ..protocol import ChatCompletionResponse, ChatCompletionStreamResponse, RequestConfig +from .patch import patch_auto_config, patch_auto_tokenizer +from .utils import AdapterRequest, patch_vllm_memory_leak + +try: + # After setting the environment variables, import vllm. This way of writing allows lint to pass. + os.environ['VLLM_USE_V1'] = os.environ.get('VLLM_USE_V1', '0') + os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn' + os.environ['VLLM_ENGINE_ITERATION_TIMEOUT_S'] = '3600' + import vllm + from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams, EngineArgs, LLM +except Exception: + raise + + +class GRPOVllmEngine(VllmEngine): + + def __init__( + self, + model_id_or_path: str, + torch_dtype: Optional[torch.dtype] = None, + *, + use_async_engine: bool = True, + model_type: Optional[str] = None, + use_hf: Optional[bool] = None, + hub_token: Optional[str] = None, + revision: Optional[str] = None, + # engine_kwargs + gpu_memory_utilization: float = 0.9, + tensor_parallel_size: int = 1, + pipeline_parallel_size: int = 1, + max_model_len: Optional[int] = None, + max_num_seqs: int = 256, + disable_custom_all_reduce: bool = False, + enforce_eager: bool = False, + limit_mm_per_prompt: Optional[Dict[str, Any]] = None, + device: str = 'auto', + # lora + enable_lora: bool = False, + max_loras: int = 1, + max_lora_rank: int = 16, + enable_prefix_caching: bool = False, + num_infer_workers: int = 1, + enable_sleep_mode: bool = False, + distributed_executor_backend: Optional[str] = None, + engine_kwargs: Optional[Dict[str, Any]] = None, + ) -> None: + patch_vllm_memory_leak() + self.use_async_engine = use_async_engine + self.processor = get_model_tokenizer( + model_id_or_path, + torch_dtype, + load_model=False, + download_model=True, + model_type=model_type, + use_hf=use_hf, + hub_token=hub_token, + revision=revision)[1] + self._post_init() + + self._prepare_engine_kwargs( + gpu_memory_utilization=gpu_memory_utilization, + tensor_parallel_size=tensor_parallel_size, + pipeline_parallel_size=pipeline_parallel_size, + max_model_len=max_model_len, + max_num_seqs=max_num_seqs, + disable_custom_all_reduce=disable_custom_all_reduce, + enforce_eager=enforce_eager, + limit_mm_per_prompt=limit_mm_per_prompt, + enable_lora=enable_lora, + max_loras=max_loras, + max_lora_rank=max_lora_rank, + enable_prefix_caching=enable_prefix_caching, + device=device, + distributed_executor_backend=distributed_executor_backend, + enable_sleep_mode=enable_sleep_mode, + engine_kwargs=engine_kwargs, + ) + self._prepare_engine() + self._load_generation_config() + + def _prepare_engine(self) -> None: + with patch_auto_tokenizer(self.tokenizer), patch_auto_config(self.config): + engine = LLM(**self.engine_args.__dict__) + self.engine = engine + + @property + def inner_model(self): + return self.engine.llm_engine.model_executor.driver_worker.model_runner.model + + @property + def inner_model_executor(self): + return self.engine.llm_engine.model_executor + + def infer( + self, + infer_requests: List[InferRequest], + request_config: Optional[RequestConfig] = None, + metrics: Optional[List[Metric]] = None, + *, + template: Optional[Template] = None, + use_tqdm: Optional[bool] = None, + adapter_request: Optional[AdapterRequest] = None, + ) -> List[Union[ChatCompletionResponse, Iterator[ChatCompletionStreamResponse]]]: + request_config = deepcopy(request_config or RequestConfig()) + if template is None: + template = self.default_template + template.set_mode('vllm') + batched_inputs, error_list = self._batch_encode( + infer_requests, template=template, strict=getattr(self, 'strict', True)) + self.set_default_max_tokens(request_config, batched_inputs) + + prompts = [] + for inputs in batched_inputs: + llm_inputs = {'prompt_token_ids': inputs['input_ids']} + mm_data = {} + for key in ['images', 'audios', 'videos']: + media_data = inputs.get(key) or [] + if media_data: + if version.parse(vllm.__version__) < version.parse('0.6'): + assert len(media_data) == 1, ( + f'The current version of vllm only supports single {key}. Please upgrade to vllm >= 0.6.0') + mm_data = {key.rstrip('s'): media_data[0]} + else: + mm_data = {key.rstrip('s'): media_data[0] if len(media_data) == 1 else media_data} + if mm_data: + llm_inputs['multi_modal_data'] = mm_data + prompts.append(llm_inputs) + + generation_configs = [] + seed = request_config.seed + assert seed >= 0, 'Seed is needed for GRPOVllmEngine.' + for i, _ in enumerate(prompts): + request_config = copy(request_config) + request_config.seed = seed + i + generation_config = self._prepare_generation_config(request_config) + self._add_stop_words(generation_config, request_config, template.template_meta) + generation_configs.append(generation_config) + outputs = self.engine.generate(prompts, generation_configs, use_tqdm=False) + return [ + self._create_chat_completion_response(result, template, generation_configs[0], '') for result in outputs + ] diff --git a/ms-swift/swift/llm/infer/infer_engine/infer_engine.py b/ms-swift/swift/llm/infer/infer_engine/infer_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..985caabce99bf36444928f743519afb56835585f --- /dev/null +++ b/ms-swift/swift/llm/infer/infer_engine/infer_engine.py @@ -0,0 +1,297 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import asyncio +import concurrent.futures +import os +from queue import Queue +from threading import Thread +from typing import Any, Dict, Iterator, List, Optional, Union + +from tqdm import tqdm + +from swift.llm import InferRequest, ProcessorMixin, get_template +from swift.llm.template import Template +from swift.llm.utils import get_ckpt_dir +from swift.plugin import Metric +from swift.utils import get_logger +from ..protocol import (ChatCompletionMessageToolCall, ChatCompletionResponse, ChatCompletionStreamResponse, + RequestConfig, UsageInfo) +from .base import BaseInferEngine + +logger = get_logger() + + +class InferEngine(BaseInferEngine, ProcessorMixin): + llm_max_batch_size = 1024 * 1024 + mllm_max_batch_size = 1024 + + def _post_init(self): + processor = self.processor + self.model_info = processor.model_info + self.model_meta = processor.model_meta + self.model_dir = self.model_info.model_dir + self.model_name = self.model_info.model_name + self.max_model_len = self.model_info.max_model_len + self.config = self.model_info.config + if getattr(self, 'default_template', None) is None: + ckpt_dir = get_ckpt_dir(self.model_dir, getattr(self, 'adapters', None)) + logger.info('Create the default_template for the infer_engine') + if ckpt_dir: + from swift.llm import BaseArguments + args = BaseArguments.from_pretrained(ckpt_dir) + self.default_template = args.get_template(self.processor) + else: + self.default_template = get_template(self.model_meta.template, self.processor) + + self._adapters_pool = {} + + def _get_stop_words(self, stop_words: List[Union[str, List[int], None]]) -> List[str]: + stop: List[str] = [] + for stop_word in stop_words: + if stop_word is None: + continue + elif isinstance(stop_word, list): + stop_word = self.tokenizer.decode(stop_word) + assert isinstance(stop_word, str) + if stop_word not in stop: + stop.append(stop_word) + return stop + + def async_iter_to_iter(self, async_iter, prog_bar, metrics) -> Iterator: + queue = Queue() + + async def _run_async_iter(): + try: + async for item in await async_iter: + queue.put(item) + except Exception as e: + if getattr(self, 'strict', True): + raise + queue.put(e) + else: + queue.put(None) + + thread = Thread(target=lambda: asyncio.run(_run_async_iter())) + thread.start() + pre_output = None + while True: + output = queue.get() + if output is None or isinstance(output, Exception): + prog_bar.update() + self._update_metrics(pre_output, metrics) + return + pre_output = output + yield output + + @staticmethod + async def batch_run(tasks): + return await asyncio.gather(*tasks) + + def _batch_infer_stream( + self, + tasks, + stream: bool = True, + use_tqdm: bool = True, + metrics: Optional[List[Metric]] = None + ) -> List[Union[ChatCompletionResponse, Iterator[ChatCompletionStreamResponse]]]: + + prog_bar = tqdm(total=len(tasks), dynamic_ncols=True, disable=not use_tqdm) + if stream: + return [self.async_iter_to_iter(task, prog_bar, metrics) for task in tasks] + else: + + async def _new_run(task): + try: + res = await task + except Exception as e: + if getattr(self, 'strict', True): + raise + res = e + prog_bar.update() + self._update_metrics(res, metrics) + return res + + new_tasks = [_new_run(task) for task in tasks] + return self.safe_asyncio_run(self.batch_run(new_tasks)) + + @staticmethod + def _get_usage_info(num_prompt_tokens: int, num_generated_tokens: int) -> UsageInfo: + return UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + num_generated_tokens, + ) + + @staticmethod + def _update_usage_info(origin_use_info: UsageInfo, num_generated_tokens: int) -> UsageInfo: + return UsageInfo( + prompt_tokens=origin_use_info.prompt_tokens, + completion_tokens=origin_use_info.completion_tokens + num_generated_tokens, + total_tokens=origin_use_info.total_tokens + num_generated_tokens, + ) + + @staticmethod + def _update_metrics(result, metrics: Optional[List[Metric]] = None): + if metrics is None: + return result + result_origin = result + if not isinstance(result, (list, tuple)): + result = [result] + for response in result: + if response is None or isinstance(response, Exception): + continue + for metric in metrics: + metric.update(response) + return result_origin + + def infer(self, + infer_requests: List[InferRequest], + request_config: Optional[RequestConfig] = None, + metrics: Optional[List[Metric]] = None, + *, + use_tqdm: Optional[bool] = None, + **kwargs) -> List[Union[ChatCompletionResponse, Iterator[ChatCompletionStreamResponse]]]: + if request_config is None: + request_config = RequestConfig() + tasks = [self.infer_async(infer_request, request_config, **kwargs) for infer_request in infer_requests] + if use_tqdm is None: + use_tqdm = not request_config.stream and len(infer_requests) > 1 + if request_config.stream: + return self._batch_infer_stream(tasks, True, use_tqdm, metrics) + else: + i = 0 + result = [] + max_batch_size = self.llm_max_batch_size + if hasattr(self, 'model_meta') and self.model_meta.is_multimodal: + # vllm & lmdeploy + max_batch_size = self.mllm_max_batch_size + prog_bar = tqdm( + total=len(infer_requests), dynamic_ncols=True, disable=not use_tqdm or len(tasks) <= max_batch_size) + while i < len(tasks): + tasks_samples = tasks[i:i + max_batch_size] + res = self._batch_infer_stream(tasks_samples, False, use_tqdm, metrics) + result += res + i += max_batch_size + prog_bar.update(len(tasks_samples)) + return result + + @staticmethod + def _get_toolcall(response: str, template: Template) -> Optional[List[ChatCompletionMessageToolCall]]: + try: + functions = template.agent_template.get_toolcall(response) + except Exception: + functions = None + if functions: + return [ChatCompletionMessageToolCall(function=function) for function in functions] + + @staticmethod + def _get_num_tokens(inputs: Dict[str, Any]) -> int: + if 'input_ids' in inputs: # 1d or 2d + input_ids = inputs['input_ids'] + if isinstance(input_ids, list): + return len(input_ids) + else: + return input_ids.shape[-1] + elif 'inputs_embeds' in inputs: # 2d or 3d + return inputs['inputs_embeds'].shape[-2] + raise ValueError(f'Unable to retrieve input_ids and inputs_embeds. inputs: {inputs}') + + def set_default_max_tokens(self, request_config: RequestConfig, inputs: Dict[str, Any]) -> None: + max_model_len = self.max_model_len + if isinstance(inputs, dict): + inputs = [inputs] + # The num_tokens takes the maximum value from inputs_list. + num_tokens = 0 + for inp in inputs: + num_tokens = max(num_tokens, self._get_num_tokens(inp)) + max_tokens = request_config.max_tokens + if max_model_len is None: + max_model_len = 8192 + logger.warning( + 'The current model is unable to retrieve `max_model_len`. It is set to the default value of 8192.') + max_max_tokens = max_model_len - num_tokens + if max_tokens is None: + request_config.max_tokens = max_max_tokens + elif max_max_tokens < request_config.max_tokens: + logger.warning(f'max_model_len({max_model_len}) - num_tokens({num_tokens}) < max_tokens({max_tokens}). ' + f'Setting max_tokens: {max_model_len - num_tokens}') + request_config.max_tokens = max_max_tokens + + def _get_logprobs(self, + logprobs_list: Optional[List[Dict[int, float]]], + token_ids: List[int], + top_logprobs: Optional[int] = None) -> Optional[Dict[str, Any]]: + if logprobs_list is None or len(token_ids) == 0: + return None + if len(token_ids) > 0: + logprobs_list = logprobs_list[-len(token_ids):] + res = [] + for logprobs, token_id in zip(logprobs_list, token_ids): + token = self.tokenizer.decode(token_id) + _res = {'token': token, 'logprob': logprobs[token_id], 'bytes': list(token.encode('utf8'))} + if top_logprobs is not None: + logprobs = {k: logprobs[k] for k in sorted(logprobs, key=lambda k: -logprobs[k])[:top_logprobs]} + res_top_logprobs = [] + for k, logprob in logprobs.items(): + if logprob == float('-inf'): + continue + token = self.tokenizer.decode(k) + res_top_logprobs.append({'token': token, 'logprob': logprob, 'bytes': list(token.encode('utf8'))}) + _res['top_logprobs'] = res_top_logprobs + res.append(_res) + return {'content': res} + + @staticmethod + def _get_finish_reason(max_tokens: int, num_prompt_tokens: int, is_finished: bool): + if is_finished: + if num_prompt_tokens >= max_tokens: + finish_reason = 'length' + else: + finish_reason = 'stop' + else: + finish_reason = None + return finish_reason + + @staticmethod + def thread_run(target, args=(), kwargs=None): + kwargs = kwargs or {} + + def func(target, queue, args, kwargs): + try: + queue.put(target(*args, **kwargs)) + except Exception as e: + queue.put(e) + + queue = Queue() + thread = Thread(target=func, args=(target, queue, args, kwargs)) + thread.start() + thread.join() + result = queue.get() + if isinstance(result, Exception): + raise result + return result + + @staticmethod + def safe_asyncio_run(coro): + return InferEngine.thread_run(asyncio.run, args=(coro, )) + + @staticmethod + def _batch_encode(infer_requests: List[InferRequest], template: Template, strict: bool): + max_workers = max(min(32, os.cpu_count(), len(infer_requests)), 1) + error_list = [] + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [ + executor.submit(template.encode, infer_request, return_template_inputs=True) + for infer_request in infer_requests + ] + concurrent.futures.wait(futures) + batched_inputs = [] + for i, future in enumerate(futures): + try: + batched_inputs.append(future.result()) + except Exception as e: + if strict: + raise + error_list.append((i, e)) + continue + return batched_inputs, error_list diff --git a/ms-swift/swift/llm/infer/infer_engine/lmdeploy_engine.py b/ms-swift/swift/llm/infer/infer_engine/lmdeploy_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..51a2d0fdccfc720acd46a944f672cb013ed3f185 --- /dev/null +++ b/ms-swift/swift/llm/infer/infer_engine/lmdeploy_engine.py @@ -0,0 +1,355 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import asyncio +import inspect +import os +import time +from contextlib import contextmanager +from copy import deepcopy +from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union + +import lmdeploy +import torch +from lmdeploy import PytorchEngineConfig, TurbomindEngineConfig, VisionConfig, pipeline +from lmdeploy.api import autoget_backend_config +from lmdeploy.serve import async_engine +from packaging import version +from transformers import GenerationConfig + +from swift.llm import InferRequest, Template, TemplateMeta, get_model_tokenizer +from swift.plugin import Metric +from swift.utils import get_logger, get_seed +from ..protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, ChatMessage, DeltaMessage, RequestConfig) +from .infer_engine import InferEngine +from .patch import patch_auto_config, patch_auto_tokenizer +from .utils import InferStreamer, patch_lmdeploy + +try: + from lmdeploy import EngineGenerationConfig as LmdeployGenerationConfig +except ImportError: + # compat lmdeploy >= 0.6.* + from lmdeploy import GenerationConfig as LmdeployGenerationConfig + +logger = get_logger() + + +class LmdeployEngine(InferEngine): + + def __init__( + self, + model_id_or_path: str, + torch_dtype: Optional[torch.dtype] = None, + *, + model_type: Optional[str] = None, + use_hf: Optional[bool] = None, + hub_token: Optional[str] = None, + revision: Optional[str] = None, + # engine_kwargs + tp: int = 1, + session_len: Optional[int] = None, + cache_max_entry_count: float = 0.8, + quant_policy: int = 0, # e.g. 4, 8 + vision_batch_size: int = 1, # max_batch_size in VisionConfig + devices: Optional[List[int]] = None, + reload_weights: bool = False, + engine_kwargs: Optional[Dict[str, Any]] = None, + ) -> None: + version_7 = version.parse(lmdeploy.__version__) >= version.parse('0.7.0') + if reload_weights: + assert version_7, 'grpo or reload_weights need lmdeploy>=0.7.0' + if version_7 and tp == 1: + patch_lmdeploy(reload_weights) + self.processor = get_model_tokenizer( + model_id_or_path, + torch_dtype, + load_model=False, + download_model=True, + model_type=model_type, + use_hf=use_hf, + hub_token=hub_token, + revision=revision)[1] + self._post_init() + + if self.max_model_len is not None: + self.max_model_len -= 1 + self._prepare_engine_kwargs( + tp=tp, + session_len=session_len, + cache_max_entry_count=cache_max_entry_count, + quant_policy=quant_policy, + vision_batch_size=vision_batch_size, + devices=devices, + engine_kwargs=engine_kwargs) + + self.config.torch_dtype = torch_dtype or self.model_info.torch_dtype + + @contextmanager + def disable_deepspeed(): + from transformers import modeling_utils + modeling_utils.is_deepspeed_zero3_enabled_origin = modeling_utils.is_deepspeed_zero3_enabled + modeling_utils.is_deepspeed_zero3_enabled = lambda: False + yield + modeling_utils.is_deepspeed_zero3_enabled = modeling_utils.is_deepspeed_zero3_enabled_origin + del modeling_utils.is_deepspeed_zero3_enabled_origin + + with disable_deepspeed(): + self._prepare_engine() + self._load_generation_config() + + def _prepare_engine_kwargs(self, + tp: int = 1, + session_len: Optional[int] = None, + cache_max_entry_count: float = 0.8, + quant_policy: int = 0, + vision_batch_size: int = 1, + devices: Optional[List[int]] = None, + engine_kwargs: Optional[Dict[str, Any]] = None): + if engine_kwargs is None: + engine_kwargs = {} + engine_kwargs['tp'] = tp + engine_kwargs['session_len'] = session_len + engine_kwargs['cache_max_entry_count'] = cache_max_entry_count + engine_kwargs['quant_policy'] = quant_policy + backend_config = TurbomindEngineConfig(**engine_kwargs) + backend_config = autoget_backend_config(self.model_dir, backend_config) + if hasattr(backend_config, 'devices'): + if devices is None: + devices = [0] + backend_config.devices = devices + self.backend_config = backend_config + logger.info(f'backend_config: {backend_config}') + + pipeline_kwargs = {} + is_multimodal = self.model_meta.is_multimodal + if is_multimodal: + vision_config = VisionConfig(max_batch_size=vision_batch_size) + pipeline_kwargs['vision_config'] = vision_config + logger.info(f'vision_config: {vision_config}') + self.pipeline_kwargs = pipeline_kwargs + + @contextmanager + def _patch_pipeline(self): + _old_best_match_model = async_engine.best_match_model + + def _best_match_model(*args, **kwargs) -> Optional[str]: + return self.model_info.model_type + + async_engine.best_match_model = _best_match_model + try: + yield + finally: + async_engine.best_match_model = _old_best_match_model + + def _prepare_engine(self): + with patch_auto_tokenizer(self.tokenizer), patch_auto_config(self.config), self._patch_pipeline(): + engine = pipeline(self.model_dir, backend_config=self.backend_config, **self.pipeline_kwargs) + self.engine = engine + + def _load_generation_config(self): + generation_config_path = os.path.join(self.model_dir, 'generation_config.json') + if os.path.isfile(generation_config_path): + generation_config = GenerationConfig.from_pretrained(self.model_dir) + kwargs = generation_config.to_dict() + max_new_tokens = kwargs.get('max_new_tokens') + if max_new_tokens is None: + kwargs.pop('max_new_tokens', None) + parameters = inspect.signature(LmdeployGenerationConfig).parameters + for k, v in kwargs.copy().items(): + if k not in parameters or v is None: + kwargs.pop(k) + self.generation_config = LmdeployGenerationConfig(**kwargs) + else: + self.generation_config = LmdeployGenerationConfig() + + def _get_stop_token_ids(self, stop_words: List[Union[str, List[int], None]]) -> List[int]: + stop_token_ids: List[int] = [] + for stop_word in stop_words: + if stop_word is None: + continue + if isinstance(stop_word, str): + stop_word = self.tokenizer.encode(stop_word, add_special_tokens=False) + if isinstance(stop_word, list): + if len(stop_word) != 1: + continue + else: + stop_token = stop_word[0] + elif isinstance(stop_word, int): + stop_token = stop_word + assert isinstance(stop_token, int) + if stop_token not in stop_token_ids: + stop_token_ids.append(stop_token) + return stop_token_ids + + def _add_stop_words(self, generation_config: LmdeployGenerationConfig, request_config: RequestConfig, + template_meta: TemplateMeta) -> None: + stop_words = (request_config.stop or []) + (self.generation_config.stop_words or []) + template_meta.stop_words + generation_config.stop_words = self._get_stop_token_ids(stop_words) + # compat lmdeploy >= 0.6.* + generation_config.stop_token_ids = generation_config.stop_words + + def _prepare_generation_config(self, request_config: RequestConfig) -> LmdeployGenerationConfig: + kwargs = {'max_new_tokens': request_config.max_tokens} + for key in ['temperature', 'top_k', 'top_p', 'repetition_penalty']: + new_value = getattr(request_config, key) + if new_value is None: + kwargs[key] = getattr(self.generation_config, key) + else: + kwargs[key] = new_value + if request_config.seed is None: + request_config.seed = get_seed() + kwargs['random_seed'] = request_config.seed + if request_config.temperature == 0: + kwargs['temperature'] = 1 # avoid unnecessary process + kwargs['top_k'] = 1 + + if request_config.logprobs: + kwargs['logprobs'] = 1 + if request_config.top_logprobs is not None: + kwargs['logprobs'] = max(1, request_config.top_logprobs) + + res = LmdeployGenerationConfig(**kwargs) + res.top_logprobs = request_config.top_logprobs + return res + + async def _infer_stream_async( + self, template: Template, inputs: Dict[str, Any], + generation_config: LmdeployGenerationConfig) -> AsyncIterator[ChatCompletionStreamResponse]: + session_id = time.time_ns() + kwargs = {'stream_output': True, 'gen_config': generation_config, 'sequence_start': True, 'sequence_end': True} + if version.parse(lmdeploy.__version__) >= version.parse('0.6.5'): + async with self.engine.model_inst(session_id) as inst: + context = self.engine.safe_run(inst, session_id, **inputs, **kwargs) + else: + context = self.engine.safe_run(session_id) + + infer_streamer = InferStreamer(template) + token_idx = 0 + async with context as gen: + if version.parse(lmdeploy.__version__) < version.parse('0.6.5'): + generator = await self.engine.get_generator(False, session_id) + gen = generator.async_stream_infer(session_id=session_id, **inputs, **kwargs) + is_finished = False + while not is_finished: + try: + output = await gen.__anext__() + except StopAsyncIteration: + is_finished = True + delta_text = infer_streamer.get_printable_text(output.token_ids, is_finished) + if not delta_text and not is_finished: + continue + + logprobs = self._get_logprobs(output.logprobs, output.token_ids[token_idx:], + generation_config.top_logprobs) + token_idx = len(output.token_ids) + + usage_info = self._get_usage_info(len(inputs['input_ids']), output.num_token) + toolcall = None + if is_finished: + toolcall = self._get_toolcall(template.decode(output.token_ids), template) + finish_reason = self._get_finish_reason(generation_config.max_new_tokens, output.num_token, + output.status.name == 'FINISH') + choices = [ + ChatCompletionResponseStreamChoice( + index=0, + delta=DeltaMessage(role='assistant', content=delta_text, tool_calls=toolcall), + finish_reason=finish_reason, + logprobs=logprobs) + ] + yield ChatCompletionStreamResponse(model=self.model_name, choices=choices, usage=usage_info) + + async def _infer_full_async(self, template: Template, inputs: Dict[str, Any], + generation_config: LmdeployGenerationConfig) -> ChatCompletionResponse: + session_id = time.time_ns() + kwargs = {'stream_output': False, 'gen_config': generation_config, 'sequence_start': True, 'sequence_end': True} + if version.parse(lmdeploy.__version__) >= version.parse('0.6.5'): + async with self.engine.model_inst(session_id) as inst: + async with self.engine.safe_run(inst, session_id, **inputs, **kwargs) as gen: + async for output in gen: + pass + if self.engine.backend == 'pytorch': + # manually end pytorch session + await inst.async_end(session_id) + + else: + async with self.engine.safe_run(session_id): + generator = await self.engine.get_generator(False, session_id) + async for output in generator.async_stream_infer(session_id=session_id, **inputs, **kwargs): + pass + + response = template.decode(output.token_ids) + logprobs = self._get_logprobs(output.logprobs, output.token_ids, generation_config.top_logprobs) + + usage_info = self._get_usage_info(len(inputs['input_ids']), output.num_token) + toolcall = self._get_toolcall(response, template) + finish_reason = self._get_finish_reason(generation_config.max_new_tokens, output.num_token, + output.status.name == 'FINISH') + choices = [ + ChatCompletionResponseChoice( + index=0, + message=ChatMessage(role='assistant', content=response, tool_calls=toolcall), + finish_reason=finish_reason, + logprobs=logprobs) + ] + return ChatCompletionResponse(model=self.model_name, choices=choices, usage=usage_info) + + async def infer_async(self, + infer_request: InferRequest, + request_config: Optional[RequestConfig] = None, + *, + template: Optional[Template] = None, + pre_infer_hook=None, + **kwargs) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionStreamResponse]]: + request_config = deepcopy(request_config or RequestConfig()) + if template is None: + template = self.default_template + + template.set_mode('lmdeploy') + + loop = asyncio.get_running_loop() + with torch.inference_mode(): + inputs = await loop.run_in_executor(None, template.encode, infer_request) + images = inputs.pop('images', None) + if images: + if version.parse(lmdeploy.__version__) >= version.parse('0.6.5'): + messages = self.engine._convert_prompts(('', images)) + messages = await self.engine.async_convert_to_pil_images(messages) + results = await self.engine.vl_encoder.preprocess(messages) + if self.engine.backend == 'turbomind': + results = await self.engine.vl_encoder.async_infer(results) + inputs['images'] = [result['content'] for result in results if result['role'] == 'forward'][0] + await template.prepare_lmdeploy_turbomind_inputs(inputs) + else: + inputs['images'] = results[1]['content'] + await template.prepare_lmdeploy_pytorch_inputs(inputs) + else: + inputs['images'] = await self.engine.vl_encoder.async_infer(images) + await template.prepare_lmdeploy_turbomind_inputs(inputs) + + self.set_default_max_tokens(request_config, inputs) + generation_config = self._prepare_generation_config(request_config) + self._add_stop_words(generation_config, request_config, template.template_meta) + kwargs.update({'template': template, 'inputs': inputs, 'generation_config': generation_config}) + if pre_infer_hook: + kwargs = pre_infer_hook(kwargs) + if request_config.stream: + return self._infer_stream_async(**kwargs) + else: + return await self._infer_full_async(**kwargs) + + def _batch_infer_stream(self, *args, **kwargs): + if hasattr(self.engine, 'vl_encoder'): + self.engine.vl_encoder._loop_task = None + if hasattr(self.engine, 'free_insts'): + self.engine.free_insts = None + return super()._batch_infer_stream(*args, **kwargs) + + def infer( + self, + infer_requests: List[InferRequest], + request_config: Optional[RequestConfig] = None, + metrics: Optional[List[Metric]] = None, + *, + template: Optional[Template] = None, + use_tqdm: Optional[bool] = None, + ) -> List[Union[ChatCompletionResponse, Iterator[ChatCompletionStreamResponse]]]: + return super().infer(infer_requests, request_config, metrics, template=template, use_tqdm=use_tqdm) diff --git a/ms-swift/swift/llm/infer/infer_engine/pt_engine.py b/ms-swift/swift/llm/infer/infer_engine/pt_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..f4724db9a8343afb2490a7e3e5887a2d562c4b25 --- /dev/null +++ b/ms-swift/swift/llm/infer/infer_engine/pt_engine.py @@ -0,0 +1,547 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import asyncio +import hashlib +import inspect +import pickle +import time +from copy import deepcopy +from queue import Queue +from threading import Thread +from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Union + +import json +import torch +from tqdm import tqdm +from transformers import GenerationConfig, LogitsProcessorList +from transformers.utils import is_torch_npu_available + +from swift.llm import InferRequest, Template, TemplateMeta, get_model_tokenizer, safe_snapshot_download, to_device +from swift.plugin import Metric +from swift.tuners import Swift +from swift.utils import get_logger +from ..protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, ChatMessage, DeltaMessage, RequestConfig, random_uuid) +from .infer_engine import InferEngine +from .utils import AdapterRequest, InferStreamer, LogitsStreamer, TokensIteratorStreamer, prepare_generation_config + +logger = get_logger() + + +class _GenerationConfig(GenerationConfig): + + def __repr__(self) -> str: + parameters = inspect.signature(self.to_json_string).parameters + kwargs = {} + if 'ignore_metadata' in parameters: + kwargs['ignore_metadata'] = True + gen_kwargs = json.loads(self.to_json_string(**kwargs)) + gen_kwargs.pop('transformers_version', None) + return f'GenerationConfig({gen_kwargs})' + + +class PtEngine(InferEngine): + + def __init__( + self, + model_id_or_path: str, + torch_dtype: Optional[torch.dtype] = None, + *, + adapters: List[str] = None, + max_batch_size: int = 1, + # + model_type: Optional[str] = None, + use_hf: Optional[bool] = None, + revision: Optional[str] = None, + hub_token: Optional[str] = None, + load_model: bool = True, + # model kwargs + attn_impl: Literal['flash_attn', 'sdpa', 'eager', None] = None, + device_map: Optional[Union[str, Dict[str, Any]]] = None, + quantization_config=None, + model_kwargs: Optional[Dict[str, Any]] = None, + **kwargs): + self.model, self.processor = get_model_tokenizer( + model_id_or_path, + torch_dtype, + load_model=load_model, + model_type=model_type, + download_model=True, + use_hf=use_hf, + hub_token=hub_token, + revision=revision, + device_map=device_map, + quantization_config=quantization_config, + attn_impl=attn_impl, + model_kwargs=model_kwargs, + **kwargs) + self.max_batch_size = max_batch_size + if isinstance(adapters, str): + adapters = [adapters] + self.adapters = adapters or [] + for adapter in self.adapters: + self._add_adapter(safe_snapshot_download(adapter, use_hf=use_hf, hub_token=hub_token)) + self._post_init() + + def _post_init(self): + super()._post_init() + self.engine = self.model # dummy + self.generation_config = self.model.generation_config + self._queue = Queue() + self._task_pool = {} + self._task_thread = None + + def _start_infer_worker(self): + self._task_thread = Thread(target=self._infer_worker, daemon=True) + self._task_thread.start() + + def _fetch_infer_requests(self): + while not self._queue.empty(): + infer_request, kwargs, queue = self._queue.get() + template = kwargs['template'] + info = hashlib.sha256(pickle.dumps((kwargs['request_config'], template + and template.template_meta))).hexdigest() + if info not in self._task_pool: + self._task_pool[info] = kwargs, [] + self._task_pool[info][1].append((infer_request, queue)) + if len(self._task_pool) == 0: + return + key, (kwargs, data) = next(iter(self._task_pool.items())) + max_batch_size = self.max_batch_size or len(data) + data, remain_data = data[:max_batch_size], data[max_batch_size:] + if remain_data: + self._task_pool[key] = kwargs, remain_data + else: + self._task_pool.pop(key) + kwargs = kwargs.copy() + kwargs['infer_requests'] = [d[0] for d in data] + queue_list = [d[1] for d in data] + return kwargs, queue_list + + def _infer_worker(self): + while True: + time.sleep(0.01) + item = self._fetch_infer_requests() + if item is not None: + kwargs, queue_list = item + request_config = kwargs['request_config'] + res_list_or_gen = self._infer(**kwargs) + if request_config.stream: + finished = False + while not finished: + try: + res_list = next(res_list_or_gen) + except StopIteration: + finished = True + res_list = [None] * len(queue_list) + for (queue, loop), res in zip(queue_list, res_list): + asyncio.run_coroutine_threadsafe(queue.put(res), loop) + else: + for (queue, loop), res in zip(queue_list, res_list_or_gen): + asyncio.run_coroutine_threadsafe(queue.put(res), loop) + + def _add_adapter(self, adapter_path: str, adapter_name: Optional[str] = None) -> None: + self.model = Swift.from_pretrained(self.model, adapter_path, adapter_name) + + @classmethod + def from_model_template(cls, model, template=None, *, max_batch_size: int = 1): + self = super().__new__(cls) + self.model = model + self.default_template = template + self.processor = template.processor + self.max_batch_size = max_batch_size + self._post_init() + return self + + def _prepare_generation_config(self, request_config: RequestConfig) -> _GenerationConfig: + generation_config = prepare_generation_config(self.generation_config, request_config, self.tokenizer) + generation_config.return_dict_in_generate = True + if request_config.logprobs: + generation_config.output_logits = True + generation_config.top_logprobs = request_config.top_logprobs + generation_config.num_return_sequences = request_config.n + return _GenerationConfig(**generation_config.to_dict()) + + def _add_stop_words(self, generation_config: _GenerationConfig, request_config: RequestConfig, + template_meta: TemplateMeta) -> None: + stop_words = (request_config.stop or []) + template_meta.stop_words + generation_config.stop_words = self._get_stop_words(stop_words) + + @staticmethod + def preprocess_logits(batched_logits: Optional[List[torch.Tensor]], batched_generate_ids: torch.Tensor, + top_logprobs: int): + batch_size = batched_generate_ids.shape[0] + if batched_logits is None: + return None + batched_logprobs = [] + for i in range(batch_size): + logprobs_list = [] + generate_ids = batched_generate_ids[i] + for j, logits in enumerate(batched_logits): + token = generate_ids[j].item() + logprobs = torch.log_softmax(logits[i], -1) + tokens = [token] + logprobs.argsort(descending=True, dim=-1)[:top_logprobs].tolist() + logprobs_list.append({token: logprobs[token].item() for token in tokens}) + batched_logprobs.append(logprobs_list) + return batched_logprobs + + @staticmethod + def _update_batched_logprobs(batched_logprobs: List[torch.Tensor], logits_streamer: Optional[LogitsStreamer], + generate_ids: torch.Tensor, top_logprobs: int) -> None: + seq_len = generate_ids.shape[1] - len(batched_logprobs[0]) + if logits_streamer is None or seq_len == 0: + return + + res = [] + for i in range(seq_len): + res.append(logits_streamer.queue.get()) + new_batched_logprobs = PtEngine.preprocess_logits(res, generate_ids[:, -seq_len:], top_logprobs) + for logprobs, new_logprobs in zip(batched_logprobs, new_batched_logprobs): + logprobs += new_logprobs + + def _infer_stream(self, + template: Template, + inputs: Dict[str, Any], + *, + generation_config: GenerationConfig, + adapter_request: Optional[AdapterRequest] = None, + **kwargs) -> Iterator[List[Optional[ChatCompletionStreamResponse]]]: + + if generation_config.num_beams != 1: + error_msg = 'Streaming generation does not support beam search.' + raise ValueError(error_msg) + streamer = TokensIteratorStreamer() + generate_kwargs = { + 'generation_config': generation_config, + 'streamer': streamer, + **inputs, + } + adapter_names = self._get_adapter_names(adapter_request) + if adapter_names is not None: + generate_kwargs['adapter_names'] = adapter_names + num_prompt_tokens = self._get_num_tokens(inputs) + + logits_streamer = None + if generation_config.output_logits: + generate_kwargs['logits_processor'] = LogitsProcessorList([LogitsStreamer()]) + + def _model_generate(**kwargs): + if is_torch_npu_available(): + torch.npu.set_device(self.model.device) + template.generate(self.model, **kwargs) + + generate_kwargs = template.prepare_generate_kwargs(generate_kwargs, model=self.model) + thread = Thread(target=_model_generate, kwargs=generate_kwargs) + thread.start() + batch_size = inputs['attention_mask'].shape[0] + all_is_finished = False + is_finished = [False] * batch_size + infer_streamers = [InferStreamer(template) for _ in range(batch_size)] + request_id_list = [f'chatcmpl-{random_uuid()}' for _ in range(batch_size)] + token_idxs = [0] * batch_size + + raw_batched_generate_ids = None # or torch.Tensor: [batch_size, seq_len] + batched_logprobs = [[] for _ in range(batch_size)] + while not all_is_finished: + try: + batched_tokens = next(streamer) + if batched_tokens.ndim == 1: + batched_tokens = batched_tokens[:, None] + + raw_batched_generate_ids = torch.concat( + [batched_tokens] + if raw_batched_generate_ids is None else [raw_batched_generate_ids, batched_tokens], + dim=1) + except StopIteration: + all_is_finished = True + + batched_generate_ids = template.get_generate_ids(raw_batched_generate_ids, num_prompt_tokens) + self._update_batched_logprobs(batched_logprobs, logits_streamer, batched_generate_ids, + generation_config.top_logprobs or 1) + + res = [] + for i in range(batched_generate_ids.shape[0]): + if is_finished[i]: + res.append(None) + continue + generate_ids = batched_generate_ids[i] + + # ignore pad_token + masks = generate_ids != self.tokenizer.pad_token_id + generate_ids = generate_ids[masks].tolist() + logprobs_list = None + if batched_logprobs[i]: + logprobs_list = [logprobs for m, logprobs in zip(masks, batched_logprobs[i]) if m.item()] + + is_finished[i] = ( + all_is_finished or is_finished[i] + or len(generate_ids) > 0 and generate_ids[-1] == self.tokenizer.pad_token_id) + delta_text = infer_streamers[i].get_printable_text(generate_ids, is_finished[i]) + if not delta_text and not is_finished[i]: + res.append(None) + continue + logprobs = self._get_logprobs(logprobs_list, generate_ids[token_idxs[i]:], + generation_config.top_logprobs) + token_idxs[i] = len(generate_ids) + + usage_info = self._get_usage_info(num_prompt_tokens, len(generate_ids)) + toolcall = None + if is_finished[i]: + toolcall = self._get_toolcall(template.decode(generate_ids), template) + finish_reason = self._get_finish_reason(generation_config.max_new_tokens, num_prompt_tokens, + is_finished[i]) + + choices = [ + ChatCompletionResponseStreamChoice( + index=0, + delta=DeltaMessage(role='assistant', content=delta_text, tool_calls=toolcall), + finish_reason=finish_reason, + logprobs=logprobs) + ] + res.append( + ChatCompletionStreamResponse( + model=self.model_name, choices=choices, usage=usage_info, id=request_id_list[i])) + if any(res): + yield res + + def _get_adapter_names(self, adapter_request: Optional[AdapterRequest]) -> Optional[List[str]]: + if adapter_request is None: + if self._adapters_pool: + return ['__base__'] + return + adapter_name = adapter_request.name + if adapter_name not in self._adapters_pool: + self._adapters_pool[adapter_name] = adapter_request + self._add_adapter(adapter_request.path, adapter_name) + return [adapter_name] + + def _infer_forward(self, + template: Template, + inputs: Dict[str, Any], + adapter_request: Optional[AdapterRequest] = None, + **kwargs): + call_kwargs = {} + top_logprobs = getattr(kwargs.get('generation_config'), 'top_logprobs', None) or 20 + adapter_names = self._get_adapter_names(adapter_request) + if adapter_names is not None: + call_kwargs['adapter_names'] = adapter_names + num_prompt_tokens = self._get_num_tokens(inputs) + inputs.pop('labels', None) + logits = self.model(**inputs, **call_kwargs).logits + if template.mode == 'seq_cls': + preds, logprobs = template.decode_seq_cls(logits, top_logprobs) + elif template.mode == 'prm': + preds = template.decode_prm(inputs['input_ids'], logits) + logprobs = [None] * len(preds) + else: + raise ValueError(f'Unsupported mode: {template.mode}') + + res = [] + for i, pred in enumerate(preds): + usage_info = self._get_usage_info(num_prompt_tokens, 1) + choices = [ + ChatCompletionResponseChoice( + index=0, + message=ChatMessage(role='assistant', content=pred, tool_calls=None), + finish_reason='stop', + logprobs=logprobs[i]) + ] + res.append(ChatCompletionResponse(model=self.model_name, choices=choices, usage=usage_info)) + return res + + def _infer_full(self, + template: Template, + inputs: Dict[str, Any], + *, + generation_config: GenerationConfig, + adapter_request: Optional[AdapterRequest] = None, + template_inputs=None) -> List[ChatCompletionResponse]: + # bos_token TODO: encoder-decoder + generate_kwargs = {'generation_config': generation_config, **inputs} + adapter_names = self._get_adapter_names(adapter_request) + if adapter_names is not None: + generate_kwargs['adapter_names'] = adapter_names + num_prompt_tokens = self._get_num_tokens(inputs) + generate_kwargs = template.prepare_generate_kwargs(generate_kwargs, model=self.model) + output = dict(template.generate(self.model, **generate_kwargs)) + output.pop('past_key_values', None) + batched_generate_ids = output['sequences'] + batched_generate_ids = template.get_generate_ids(batched_generate_ids, num_prompt_tokens) + template.debug_logger({'generate_ids': batched_generate_ids}) # debug + batched_logprobs = self.preprocess_logits( + output.get('logits'), batched_generate_ids, generation_config.top_logprobs) + + res = [] + num_return_sequences = generation_config.num_return_sequences + for i in range(inputs['attention_mask'].shape[0]): + choices = [] + usage_info = self._get_usage_info(num_prompt_tokens, 0) + for j in range(num_return_sequences): + batched_index = i * num_return_sequences + j + generate_ids = batched_generate_ids[batched_index] + + # ignore pad_token + masks = generate_ids != self.tokenizer.pad_token_id + generate_ids = generate_ids[masks].tolist() + logprobs_list = None + if batched_logprobs is not None: + logprobs_list = [ + logprobs for m, logprobs in zip(masks, batched_logprobs[batched_index]) if m.item() + ] + + logprobs = self._get_logprobs(logprobs_list, generate_ids, generation_config.top_logprobs) + usage_info = self._update_usage_info(usage_info, len(generate_ids)) + response = template.decode(generate_ids, template_inputs=template_inputs[i]) + finish_reason = self._get_finish_reason(generation_config.max_new_tokens, num_prompt_tokens, True) + toolcall = self._get_toolcall(response, template) + choices.append( + ChatCompletionResponseChoice( + index=j, + message=ChatMessage(role='assistant', content=response, tool_calls=toolcall), + finish_reason=finish_reason, + logprobs=logprobs)) + res.append(ChatCompletionResponse(model=self.model_name, choices=choices, usage=usage_info)) + return res + + async def infer_async( + self, + infer_request: InferRequest, + request_config: Optional[RequestConfig] = None, + *, + template: Optional[Template] = None, + adapter_request: Optional[AdapterRequest] = None, + pre_infer_hook=None, + ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionStreamResponse]]: + if request_config is None: + request_config = RequestConfig() + queue = asyncio.Queue() + self._queue.put((infer_request, { + 'request_config': request_config, + 'template': template, + 'adapter_request': adapter_request, + 'pre_infer_hook': pre_infer_hook + }, (queue, asyncio.get_event_loop()))) + await asyncio.sleep(0) + if self._task_thread is None: + self._start_infer_worker() + if request_config.stream: + + async def _gen_wrapper(): + while True: + item = await queue.get() + await asyncio.sleep(0) + if item is None: + break + yield item + + return _gen_wrapper() + else: + return await queue.get() + + @staticmethod + def _add_error_list(outputs, error_list): + for i, error in error_list: + outputs.insert(i, error) + return outputs + + # Ensure `template._post_encode` has no gradient. + @torch.inference_mode() + def _infer( + self, + infer_requests: List[InferRequest], + request_config: RequestConfig, + *, + template: Optional[Template] = None, + adapter_request: Optional[AdapterRequest] = None, + pre_infer_hook=None, + ) -> Union[List[ChatCompletionResponse], Iterator[List[Optional[ChatCompletionStreamResponse]]]]: + self.model.eval() + request_config = deepcopy(request_config) + if template is None: + template = self.default_template + if template.use_model: + template.model = self.model + + generation_config = None + if self.model_info.task_type == 'causal_lm': + template.set_mode('pt') + + batched_inputs, error_list = self._batch_encode( + infer_requests, template=template, strict=getattr(self, 'strict', True)) + if len(batched_inputs) > 0: + template_inputs = [inputs.pop('template_inputs') for inputs in batched_inputs] + inputs = to_device(template.data_collator(batched_inputs), self.model.device) + template.debug_logger(inputs) # debug + if self.model.model_meta.is_multimodal: + _, inputs = template.pre_forward_hook(self.model, None, inputs) + if self.model_info.task_type == 'causal_lm': + self.set_default_max_tokens(request_config, inputs) + generation_config = self._prepare_generation_config(request_config) + self._add_stop_words(generation_config, request_config, template.template_meta) + else: + generation_config = request_config + + kwargs = { + 'template': template, + 'inputs': inputs, + 'generation_config': generation_config, + 'adapter_request': adapter_request, + 'template_inputs': template_inputs + } + if pre_infer_hook: + kwargs = pre_infer_hook(kwargs) + else: + kwargs = {} + if request_config.stream: + + def _gen_wrapper(): + if len(kwargs) > 0: + for res in self._infer_stream(**kwargs): + yield self._add_error_list(res, error_list) + else: + yield self._add_error_list([], error_list) + + return _gen_wrapper() + else: + if len(kwargs) > 0: + infer_func = self._infer_forward if template.mode in ('seq_cls', 'prm') else self._infer_full + res = infer_func(**kwargs) + else: + res = [] + return self._add_error_list(res, error_list) + + def infer( + self, + infer_requests: List[InferRequest], + request_config: Optional[RequestConfig] = None, + metrics: Optional[List[Metric]] = None, + *, + template: Optional[Template] = None, + use_tqdm: Optional[bool] = None, + adapter_request: Optional[AdapterRequest] = None + ) -> List[Union[ChatCompletionResponse, Iterator[ChatCompletionStreamResponse]]]: + if request_config is None: + request_config = RequestConfig() + if request_config.stream: + return super().infer( + infer_requests, + request_config, + metrics, + template=template, + use_tqdm=use_tqdm, + adapter_request=adapter_request) + # Has higher stability than calling super().infer + if use_tqdm is None: + use_tqdm = not request_config.stream and len(infer_requests) > 1 + prog_bar = tqdm(total=len(infer_requests), dynamic_ncols=True, disable=not use_tqdm) + # If self.max_batch_size is None or 0, then process all infer_requests at once. + max_batch_size = self.max_batch_size or len(infer_requests) + res = [] + i = 0 + while i < len(infer_requests): + infer_requests_samples = infer_requests[i:i + max_batch_size] + res += self._infer( + infer_requests_samples, request_config, template=template, adapter_request=adapter_request) + i += max_batch_size + prog_bar.update(len(infer_requests_samples)) + self._update_metrics(res, metrics) + return res diff --git a/ms-swift/swift/llm/infer/infer_engine/vllm_engine.py b/ms-swift/swift/llm/infer/infer_engine/vllm_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..4e968086d624d450f4adb96dacb47c3f127dd218 --- /dev/null +++ b/ms-swift/swift/llm/infer/infer_engine/vllm_engine.py @@ -0,0 +1,505 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import asyncio +import inspect +import os +from contextlib import nullcontext +from copy import deepcopy +from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union + +import torch +from packaging import version +from tqdm import tqdm +from transformers import GenerationConfig + +from swift.llm import InferRequest, Template, TemplateMeta, get_model_tokenizer +from swift.plugin import Metric +from swift.utils import get_logger, get_node_setting, get_seed +from ..protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, ChatMessage, DeltaMessage, RequestConfig, random_uuid) +from .infer_engine import InferEngine +from .patch import patch_auto_config, patch_auto_tokenizer +from .utils import AdapterRequest, InferStreamer, patch_npu_vllm, patch_vllm + +try: + # After setting the environment variables, import vllm. This way of writing allows lint to pass. + os.environ['VLLM_USE_V1'] = os.environ.get('VLLM_USE_V1', '0') + os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn' + os.environ['VLLM_ENGINE_ITERATION_TIMEOUT_S'] = '3600' + import vllm + from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams, EngineArgs, LLMEngine +except Exception: + raise + +logger = get_logger() +dtype_mapping = {torch.float16: 'float16', torch.bfloat16: 'bfloat16', torch.float32: 'float32'} + + +class VllmEngine(InferEngine): + + def __init__( + self, + model_id_or_path: str, + torch_dtype: Optional[torch.dtype] = None, + *, + use_async_engine: bool = True, + model_type: Optional[str] = None, + use_hf: Optional[bool] = None, + hub_token: Optional[str] = None, + revision: Optional[str] = None, + # engine_kwargs + gpu_memory_utilization: float = 0.9, + tensor_parallel_size: int = 1, + pipeline_parallel_size: int = 1, + max_model_len: Optional[int] = None, + max_num_seqs: int = 256, + disable_custom_all_reduce: bool = False, + enforce_eager: bool = False, + limit_mm_per_prompt: Optional[Dict[str, Any]] = None, + device: str = 'auto', + # lora + enable_lora: bool = False, + max_loras: int = 1, + max_lora_rank: int = 16, + enable_prefix_caching: bool = False, + num_infer_workers: int = 1, + enable_sleep_mode: bool = False, + distributed_executor_backend: Optional[str] = None, + quantization: Optional[str] = None, + engine_kwargs: Optional[Dict[str, Any]] = None, + ) -> None: + self.use_async_engine = use_async_engine + self.processor = get_model_tokenizer( + model_id_or_path, + torch_dtype, + load_model=False, + download_model=True, + model_type=model_type, + use_hf=use_hf, + hub_token=hub_token, + revision=revision)[1] + self._post_init() + + self._prepare_engine_kwargs( + gpu_memory_utilization=gpu_memory_utilization, + tensor_parallel_size=tensor_parallel_size, + pipeline_parallel_size=pipeline_parallel_size, + max_model_len=max_model_len, + max_num_seqs=max_num_seqs, + disable_custom_all_reduce=disable_custom_all_reduce, + enforce_eager=enforce_eager, + limit_mm_per_prompt=limit_mm_per_prompt, + enable_lora=enable_lora, + max_loras=max_loras, + max_lora_rank=max_lora_rank, + enable_prefix_caching=enable_prefix_caching, + device=device, + distributed_executor_backend=distributed_executor_backend, + enable_sleep_mode=enable_sleep_mode, + quantization=quantization, + engine_kwargs=engine_kwargs, + ) + nnodes = get_node_setting()[1] + total_infer_workers = num_infer_workers * nnodes + context, npu_context = patch_vllm(world_size=total_infer_workers), nullcontext() + if tensor_parallel_size == 1 or pipeline_parallel_size == 1: + npu_context = patch_npu_vllm(self.engine_args.device) + with context, npu_context: + self._prepare_engine() + self._load_generation_config() + self._fix_vllm_bug() + self.patch_remove_log() + + def _prepare_engine(self) -> None: + with patch_auto_tokenizer(self.tokenizer), patch_auto_config(self.config): + llm_engine_cls = AsyncLLMEngine if self.use_async_engine else LLMEngine + engine = llm_engine_cls.from_engine_args(self.engine_args) + self.engine = engine + + def _prepare_engine_kwargs( + self, + gpu_memory_utilization: float = 0.9, + tensor_parallel_size: int = 1, + pipeline_parallel_size: int = 1, + max_model_len: Optional[int] = None, + max_num_seqs: int = 256, + disable_custom_all_reduce: bool = False, + enforce_eager: bool = False, + limit_mm_per_prompt: Optional[Dict[str, Any]] = None, + device: str = 'auto', + enable_lora: bool = False, + max_loras: int = 1, + max_lora_rank: int = 16, + enable_prefix_caching: bool = False, + distributed_executor_backend: Optional[str] = None, + enable_sleep_mode: bool = False, + quantization: Optional[str] = None, + engine_kwargs: Optional[Dict[str, Any]] = None, + ) -> None: + if engine_kwargs is None: + engine_kwargs = {} + disable_log_stats = engine_kwargs.pop('disable_log_stats', True) + if self.use_async_engine: + engine_cls = AsyncEngineArgs + engine_kwargs['disable_log_requests'] = True + else: + engine_cls = EngineArgs + parameters = inspect.signature(engine_cls).parameters + if 'enable_lora' in parameters and enable_lora: + engine_kwargs['enable_lora'] = enable_lora + engine_kwargs['max_loras'] = max_loras + engine_kwargs['max_lora_rank'] = max_lora_rank + else: + assert not enable_lora, 'The current version of vLLM does not support `enable_lora`. Please upgrade vLLM.' + + if 'limit_mm_per_prompt' in parameters and limit_mm_per_prompt: + engine_kwargs['limit_mm_per_prompt'] = limit_mm_per_prompt + else: + assert not limit_mm_per_prompt, ( + 'The current version of VLLM does not support `limit_mm_per_prompt`. Please upgrade VLLM.') + if 'enable_sleep_mode' in parameters: + engine_kwargs['enable_sleep_mode'] = enable_sleep_mode + + engine_kwargs['quantization'] = quantization + model_info = self.model_info + if self.config.architectures is None: + architectures = {'deepseek_vl2': ['DeepseekVLV2ForCausalLM']}[self.model_meta.model_type] + engine_kwargs['hf_overrides'] = {'architectures': architectures} + engine_args = engine_cls( + model=self.model_dir, + dtype=dtype_mapping[model_info.torch_dtype], + gpu_memory_utilization=gpu_memory_utilization, + tensor_parallel_size=tensor_parallel_size, + pipeline_parallel_size=pipeline_parallel_size, + max_model_len=max_model_len, + max_num_seqs=max_num_seqs, + disable_log_stats=disable_log_stats, + disable_custom_all_reduce=disable_custom_all_reduce, + enforce_eager=enforce_eager, + trust_remote_code=True, + enable_prefix_caching=enable_prefix_caching, + distributed_executor_backend=distributed_executor_backend, + device=device, + **engine_kwargs, + ) + if distributed_executor_backend == 'external_launcher': + engine_args.disable_custom_all_reduce = True + self.engine_args = engine_args + self.enable_lora = enable_lora + if max_model_len is not None: + model_info.max_model_len = max_model_len + + def _fix_vllm_bug(self) -> None: + # fix vllm==0.4 bug (very slow) + tokenizer = self.tokenizer + if self._version_ge('0.4') and not tokenizer.__class__.__name__.startswith('Cached'): + _tokenizer_len = len(tokenizer) + __old_len__ = tokenizer.__class__.__len__ + + def __len__(self) -> int: + if self is tokenizer: + return _tokenizer_len + else: + return __old_len__(self) + + tokenizer.__class__.__len__ = __len__ + + def _load_generation_config(self) -> None: + generation_config_path = os.path.join(self.model_dir, 'generation_config.json') + if os.path.isfile(generation_config_path): + generation_config = GenerationConfig.from_pretrained(self.model_dir) + kwargs = generation_config.to_dict() + max_new_tokens = kwargs.get('max_new_tokens') + if max_new_tokens is not None: + kwargs['max_tokens'] = max_new_tokens + top_k = kwargs.get('top_k') + if top_k == 0: + kwargs['top_k'] = -1 + parameters = inspect.signature(SamplingParams).parameters + for k, v in kwargs.copy().items(): + if k not in parameters or v is None: + kwargs.pop(k) + self.generation_config = SamplingParams(**kwargs) + else: + self.generation_config = SamplingParams() + + def _add_stop_words(self, generation_config: SamplingParams, request_config: RequestConfig, + template_meta: TemplateMeta) -> None: + stop_words = (request_config.stop or []) + (self.generation_config.stop or []) + template_meta.stop_words + generation_config.stop = self._get_stop_words(stop_words) + + @staticmethod + def _version_ge(base_version: str): + vllm_version = vllm.__version__ + if vllm_version is None or 'dev' in vllm_version: + return True + return version.parse(vllm_version) >= version.parse(base_version) + + def _add_request(self, + inputs: Dict[str, Any], + generation_config: SamplingParams, + request_id: str, + adapter_request: Optional[AdapterRequest] = None): + kwargs = {} + if self.enable_lora and adapter_request: + from vllm.lora.request import LoRARequest + adapter_name = adapter_request.name + adapter_path = adapter_request.path + if adapter_name in self._adapters_pool: + kwargs['lora_request'] = self._adapters_pool[adapter_name] + else: + kwargs['lora_request'] = LoRARequest( + lora_name=adapter_name, lora_path=adapter_path, lora_int_id=len(self._adapters_pool) + 1) + self._adapters_pool[adapter_name] = kwargs['lora_request'] + input_ids = inputs['input_ids'] + if self._version_ge('0.4.3'): + llm_inputs = {'prompt_token_ids': input_ids} + mm_data = {} + for key in ['images', 'audios', 'videos']: + media_data = inputs.get(key) or [] + if media_data: + if self._version_ge('0.6'): + mm_data = {key.rstrip('s'): media_data[0] if len(media_data) == 1 else media_data} + else: + assert len(media_data) == 1, ( + f'The current version of vllm only supports single {key}. Please upgrade to vllm >= 0.6.0') + mm_data = {key.rstrip('s'): media_data[0]} + if mm_data: + llm_inputs['multi_modal_data'] = mm_data + if self.use_async_engine: + return self.engine.generate(llm_inputs, generation_config, request_id, **kwargs) + else: + return self.engine.add_request(request_id, llm_inputs, generation_config, **kwargs) + else: + if self.use_async_engine: + return self.engine.generate(None, generation_config, request_id, input_ids, **kwargs) + else: + return self.engine.add_request(request_id, None, generation_config, input_ids, **kwargs) + + def _get_logprobs(self, + logprobs_list: Optional[List[Dict[int, float]]], + token_ids: List[int], + top_logprobs: Optional[int] = None) -> Optional[Dict[str, Any]]: + if logprobs_list is None or len(token_ids) == 0: + return None + if len(token_ids) > 0: + logprobs_list = logprobs_list[-len(token_ids):] + for logprobs in logprobs_list: + for token_id, logprob in logprobs.items(): + logprobs[token_id] = logprob.logprob + return super()._get_logprobs(logprobs_list, token_ids, top_logprobs) + + def _prepare_generation_config(self, request_config: RequestConfig) -> SamplingParams: + kwargs = {'max_tokens': request_config.max_tokens} + for key in ['temperature', 'top_k', 'top_p', 'repetition_penalty']: + new_value = getattr(request_config, key) + if new_value is None: + kwargs[key] = getattr(self.generation_config, key) + else: + kwargs[key] = new_value + + if request_config.logprobs: + kwargs['logprobs'] = 1 + if request_config.top_logprobs is not None: + kwargs['logprobs'] = max(1, request_config.top_logprobs) + + # TODO: beam search + for key in ['n', 'best_of', 'frequency_penalty', 'presence_penalty', 'seed']: + kwargs[key] = getattr(request_config, key) + + if kwargs.get('seed') is None: + kwargs['seed'] = get_seed() + res = SamplingParams(**kwargs) + res.top_logprobs = request_config.top_logprobs + return res + + @property + def inner_model(self): + return self.engine.model_executor.driver_worker.worker.model_runner.model + + @property + def inner_model_executor(self): + return self.engine.model_executor + + async def _infer_stream_async(self, template: Template, inputs: Dict[str, Any], generation_config: SamplingParams, + **kwargs) -> AsyncIterator[ChatCompletionStreamResponse]: + request_id = random_uuid() + result_generator = self._add_request(inputs, generation_config, request_id, **kwargs) + infer_streamers = [InferStreamer(template) for _ in range(generation_config.n)] + token_idxs = [0 for _ in range(generation_config.n)] + async for result in result_generator: + + is_diff = False + is_finished = False + for output in result.outputs: + output.token_ids = list(output.token_ids) + output.delta_text = infer_streamers[output.index].get_printable_text( + output.token_ids, output.finished()) + output.is_finished = output.finish_reason is not None + is_diff |= bool(output.delta_text) + is_finished |= output.is_finished + if not is_diff and not is_finished: + continue + + num_generated_tokens = sum(len(output.token_ids) for output in result.outputs) + usage_info = self._get_usage_info(len(result.prompt_token_ids), num_generated_tokens) + choices = [] + for output in result.outputs: + logprobs = self._get_logprobs(output.logprobs, output.token_ids[token_idxs[output.index]:], + generation_config.top_logprobs) + token_idxs[output.index] = len(output.token_ids) + toolcall = None + if output.is_finished: + toolcall = self._get_toolcall(template.decode(output.token_ids), template) + choice = ChatCompletionResponseStreamChoice( + index=output.index, + delta=DeltaMessage(role='assistant', content=output.delta_text, tool_calls=toolcall), + finish_reason=output.finish_reason, + logprobs=logprobs) + choices.append(choice) + yield ChatCompletionStreamResponse(model=self.model_name, choices=choices, usage=usage_info, id=request_id) + + def _create_chat_completion_response(self, result, template, generation_config, request_id): + assert result is not None + num_generated_tokens = sum(len(output.token_ids) for output in result.outputs) + usage_info = self._get_usage_info(len(result.prompt_token_ids), num_generated_tokens) + choices = [] + for output in result.outputs: + output.token_ids = list(output.token_ids) + response = template.decode(output.token_ids) + logprobs = self._get_logprobs(output.logprobs, output.token_ids, generation_config.top_logprobs) + toolcall = self._get_toolcall(response, template) + choice = ChatCompletionResponseChoice( + index=output.index, + message=ChatMessage(role='assistant', content=response, tool_calls=toolcall), + finish_reason=output.finish_reason, + logprobs=logprobs) + choices.append(choice) + return ChatCompletionResponse(model=self.model_name, choices=choices, usage=usage_info, id=request_id) + + async def _infer_full_async( + self, + template: Template, + inputs: Dict[str, Any], + generation_config: SamplingParams, + adapter_request: Optional[AdapterRequest] = None, + ) -> ChatCompletionResponse: + request_id = random_uuid() + result_generator = self._add_request(inputs, generation_config, request_id, adapter_request=adapter_request) + result = None + async for result in result_generator: + pass + return self._create_chat_completion_response(result, template, generation_config, request_id) + + def _batch_infer_stream(self, *args, **kwargs): + if hasattr(self.engine, 'engine'): + self.engine.engine.model_executor.parallel_worker_tasks = None + elif hasattr(self.engine, 'engine_core'): + # vllm>=0.8 + self.engine.engine_core.outputs_queue = None + self.engine.engine_core.queue_task = None + self.engine.output_handler = None + return super()._batch_infer_stream(*args, **kwargs) + + def infer( + self, + infer_requests: List[InferRequest], + request_config: Optional[RequestConfig] = None, + metrics: Optional[List[Metric]] = None, + *, + template: Optional[Template] = None, + use_tqdm: Optional[bool] = None, + adapter_request: Optional[AdapterRequest] = None, + ) -> List[Union[ChatCompletionResponse, Iterator[ChatCompletionStreamResponse]]]: + if self.use_async_engine: + return super().infer( + infer_requests, + request_config, + metrics, + template=template, + use_tqdm=use_tqdm, + adapter_request=adapter_request, + ) + else: + request_config = deepcopy(request_config or RequestConfig()) + if request_config.stream: + raise ValueError('If you want to use stream inference, you need to pass `use_async_engine` as True.') + if use_tqdm is None: + use_tqdm = len(infer_requests) > 1 + if template is None: + template = self.default_template + template.set_mode('vllm') + batched_inputs, error_list = self._batch_encode( + infer_requests, template=template, strict=getattr(self, 'strict', True)) + self.set_default_max_tokens(request_config, batched_inputs) + request_id_list = [] + for inputs in batched_inputs: + request_id = random_uuid() + request_id_list.append(request_id) + generation_config = self._prepare_generation_config(request_config) + self._add_stop_words(generation_config, request_config, template.template_meta) + self._add_request(inputs, generation_config, request_id, adapter_request=adapter_request) + prog_bar = tqdm(total=len(batched_inputs), dynamic_ncols=True, disable=not use_tqdm) + outputs = {} + while self.engine.has_unfinished_requests(): + step_outputs = self.engine.step() + for output in step_outputs: + if output.finished: + outputs[output.request_id] = output + prog_bar.update() + prog_bar.close() + outputs = [outputs[request_id] for request_id in request_id_list] + return [ + self._create_chat_completion_response(result, template, generation_config, request_id) + for request_id, result in zip(request_id_list, outputs) + ] + + async def infer_async( + self, + infer_request: InferRequest, + request_config: Optional[RequestConfig] = None, + *, + template: Optional[Template] = None, + adapter_request: Optional[AdapterRequest] = None, + pre_infer_hook=None, + ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionStreamResponse]]: + if not self.use_async_engine: + raise ValueError('If you want to use `infer_async`, you need to pass `use_async_engine` as True.') + request_config = deepcopy(request_config or RequestConfig()) + if template is None: + template = self.default_template + + template.set_mode('vllm') + loop = asyncio.get_running_loop() + with torch.inference_mode(): + inputs = await loop.run_in_executor(None, template.encode, infer_request) + self.set_default_max_tokens(request_config, inputs) + generation_config = self._prepare_generation_config(request_config) + self._add_stop_words(generation_config, request_config, template.template_meta) + kwargs = { + 'template': template, + 'inputs': inputs, + 'generation_config': generation_config, + 'adapter_request': adapter_request, + } + if pre_infer_hook: + kwargs = pre_infer_hook(kwargs) + if request_config.stream: + return self._infer_stream_async(**kwargs) + else: + return await self._infer_full_async(**kwargs) + + @staticmethod + def patch_remove_log(): + from vllm.engine import async_llm_engine + + async_llm_engine._origin_log_task_completion = async_llm_engine._log_task_completion + + def new_log_task_completion(task, error_callback) -> None: + try: + return_value = task.result() + raise AssertionError(f'The engine background task should never finish without an ' + f'exception. {return_value}') + except asyncio.exceptions.CancelledError: + pass + + async_llm_engine._log_task_completion = new_log_task_completion diff --git a/ms-swift/swift/llm/infer/utils.py b/ms-swift/swift/llm/infer/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6510b989f3de34e2f70726469f1137e9349c3fc7 --- /dev/null +++ b/ms-swift/swift/llm/infer/utils.py @@ -0,0 +1,147 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import re +from copy import deepcopy +from dataclasses import dataclass, field +from typing import List, Literal, Optional + +from swift.plugin import extra_tuners +from swift.tuners import Swift +from swift.utils import get_logger +from ..utils import Messages + +logger = get_logger() + + +@dataclass +class InferCliState: + # None: use default-system. '': not use system. + system: Optional[str] = None + messages: Messages = field(default_factory=list) # not including system + + images: List[str] = field(default_factory=list) + audios: List[str] = field(default_factory=list) + videos: List[str] = field(default_factory=list) + + multiline_mode: bool = False + input_system: bool = False + + def clear(self): + self.messages = [] + self.images = [] + self.audios = [] + self.videos = [] + + def add_query(self, query: str) -> None: + role = 'user' + if query.startswith('tool:'): + role = 'tool' + query = query[len('tool:'):] + self.messages.append({'role': role, 'content': query}) + + def add_response(self, response: str) -> None: + self.messages.append({'role': 'assistant', 'content': response}) + + def to_dict(self): + infer_state = deepcopy(self) + if infer_state.system is not None: + infer_state.messages.insert(0, {'role': 'system', 'content': infer_state.system}) + return { + 'messages': infer_state.messages, + 'images': infer_state.images, + 'audios': infer_state.audios, + 'videos': infer_state.videos + } + + def input_mm_data(self) -> None: + + def _input_mm_file(mm_type: Literal['image', 'video', 'audio']) -> str: + a_an = 'an' if mm_type[0] in {'i', 'a'} else 'a' + return input(f'Input {a_an} {mm_type} path or URL <<< ') + + mm_types = ['image', 'video', 'audio'] + query = self.messages[-1]['content'] + mm_tags = re.findall('|'.join(f'<{mm_type}>' for mm_type in mm_types), query) + # mm_tag -> mm_type/mm_key + mm_mapping = {f'<{mm_type}>': (mm_type, f'{mm_type}s') for mm_type in mm_types} + for mm_tag in mm_tags: + mm_type, mm_key = mm_mapping[mm_tag] + mm_val = getattr(self, mm_key) + mm_val.append(_input_mm_file(mm_type)) + + @staticmethod + def _input_multiline(prompt: str) -> str: + query = '' + stop_words = '#\n' + while True: + text = f'{input(prompt)}\n' + prompt = '' + if text.endswith(stop_words): + query += text[:-len(stop_words)] + break + query += text + return query + + def input_text(self) -> str: + if self.multiline_mode: + addi_prompt = '[MS]' if self.input_system else '[M]' + text = InferCliState._input_multiline(f'<<<{addi_prompt} ') + else: + addi_prompt = '[S]' if self.input_system else '' + text = input(f'<<<{addi_prompt} ') + return text + + def check_query(self, query: str) -> Optional[str]: + query_std = query.strip().lower() + if self.input_system: + if query == 'default-system': + self.system = None + else: + self.system = query + self.input_system = False + query_std = 'clear' + if query_std == 'clear': + self.clear() + return + if query_std == '': + return + if query_std == 'reset-system': + self.input_system = True + return + if query_std == 'multi-line': + self.multiline_mode = True + logger.info('End multi-line input with `#`.') + logger.info('Input `single-line` to switch to single-line input mode.') + return + if query_std == 'single-line': + self.multiline_mode = False + return + return query + + +def prepare_adapter(args, model, adapters=None): + if args.tuner_backend == 'unsloth': + if args.model_meta.is_multimodal: + from unsloth import FastVisionModel as UnslothModel + else: + from unsloth import FastLanguageModel as UnslothModel + UnslothModel.for_inference(model) + return model + if args.train_type in extra_tuners: + tuner = extra_tuners[args.train_type] + else: + tuner = Swift + # compat deploy + adapters = adapters or args.adapters + for adapter in adapters: + model = tuner.from_pretrained(model, adapter) + if args.train_type == 'bone': + # Bone has a problem of float32 matmul with bloat16 in `peft==0.14.0` + model.to(model.dtype) + return model + + +def prepare_model_template(args, **kwargs): + model, processor = args.get_model_processor(**kwargs) + model = prepare_adapter(args, model) + template = args.get_template(processor) + return model, template diff --git a/ms-swift/swift/llm/model/__pycache__/__init__.cpython-310.pyc b/ms-swift/swift/llm/model/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a252d7270e6c9f10d1966a08310d89b989f9297 Binary files /dev/null and b/ms-swift/swift/llm/model/__pycache__/__init__.cpython-310.pyc differ diff --git a/ms-swift/swift/llm/model/__pycache__/register.cpython-310.pyc b/ms-swift/swift/llm/model/__pycache__/register.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f5ff23b73483c878fa830e68e747eae797a2d33 Binary files /dev/null and b/ms-swift/swift/llm/model/__pycache__/register.cpython-310.pyc differ diff --git a/ms-swift/swift/llm/model/constant.py b/ms-swift/swift/llm/model/constant.py new file mode 100644 index 0000000000000000000000000000000000000000..c504ac3cbb56f9f48e8cb3f6d0d2c73336d4bc80 --- /dev/null +++ b/ms-swift/swift/llm/model/constant.py @@ -0,0 +1,246 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Classification criteria for model_type: same model architecture, tokenizer (get function), template. +from itertools import chain +from typing import List + + +class LLMModelType: + qwen = 'qwen' + qwen2 = 'qwen2' + qwen2_5 = 'qwen2_5' + qwen2_5_math = 'qwen2_5_math' + qwen2_moe = 'qwen2_moe' + qwq_preview = 'qwq_preview' + qwq = 'qwq' + qwen3 = 'qwen3' + qwen3_moe = 'qwen3_moe' + + qwen2_gte = 'qwen2_gte' + + codefuse_qwen = 'codefuse_qwen' + modelscope_agent = 'modelscope_agent' + marco_o1 = 'marco_o1' + + llama = 'llama' + llama3 = 'llama3' + llama3_1 = 'llama3_1' + llama3_2 = 'llama3_2' + reflection = 'reflection' + megrez = 'megrez' + yi = 'yi' + yi_coder = 'yi_coder' + sus = 'sus' + + codefuse_codellama = 'codefuse_codellama' + mengzi3 = 'mengzi3' + ziya = 'ziya' + numina = 'numina' + atom = 'atom' + + chatglm2 = 'chatglm2' + chatglm3 = 'chatglm3' + glm4 = 'glm4' + glm4_0414 = 'glm4_0414' + glm4_z1_rumination = 'glm4_z1_rumination' + + glm_edge = 'glm_edge' + codefuse_codegeex2 = 'codefuse_codegeex2' + codegeex4 = 'codegeex4' + longwriter_llama3_1 = 'longwriter_llama3_1' + + internlm = 'internlm' + internlm2 = 'internlm2' + internlm3 = 'internlm3' + + deepseek = 'deepseek' + deepseek_moe = 'deepseek_moe' + deepseek_v2 = 'deepseek_v2' + deepseek_v2_5 = 'deepseek_v2_5' + deepseek_r1 = 'deepseek_r1' + deepseek_r1_distill = 'deepseek_r1_distill' + + openbuddy_llama = 'openbuddy_llama' + openbuddy_llama3 = 'openbuddy_llama3' + openbuddy_mistral = 'openbuddy_mistral' + openbuddy_mixtral = 'openbuddy_mixtral' + + baichuan = 'baichuan' + baichuan2 = 'baichuan2' + baichuan_m1 = 'baichuan_m1' + + minicpm = 'minicpm' + minicpm_chatml = 'minicpm_chatml' + minicpm3 = 'minicpm3' + minicpm_moe = 'minicpm_moe' + + telechat = 'telechat' + telechat2 = 'telechat2' + + mistral = 'mistral' + zephyr = 'zephyr' + mixtral = 'mixtral' + mistral_nemo = 'mistral_nemo' + mistral_2501 = 'mistral_2501' + wizardlm2 = 'wizardlm2' + wizardlm2_moe = 'wizardlm2_moe' + + phi2 = 'phi2' + phi3_small = 'phi3_small' + phi3 = 'phi3' + phi3_moe = 'phi3_moe' + phi4 = 'phi4' + + minimax = 'minimax' + + gemma = 'gemma' + gemma2 = 'gemma2' + gemma3_text = 'gemma3_text' + + skywork = 'skywork' + skywork_o1 = 'skywork_o1' + + ling = 'ling' + yuan2 = 'yuan2' + orion = 'orion' + xverse = 'xverse' + xverse_moe = 'xverse_moe' + seggpt = 'seggpt' + bluelm = 'bluelm' + c4ai = 'c4ai' + dbrx = 'dbrx' + grok = 'grok' + mamba = 'mamba' + polylm = 'polylm' + aya = 'aya' + moonlight = 'moonlight' + mimo = 'mimo' + + +class BertModelType: + modern_bert = 'modern_bert' + modern_bert_gte = 'modern_bert_gte' + bert = 'bert' + + +class RMModelType: + internlm2_reward = 'internlm2_reward' + qwen2_reward = 'qwen2_reward' + qwen2_5_prm = 'qwen2_5_prm' + qwen2_5_math_reward = 'qwen2_5_math_reward' + llama3_2_reward = 'llama3_2_reward' + gemma_reward = 'gemma_reward' + + +class MLLMModelType: + qwen_vl = 'qwen_vl' + qwen_audio = 'qwen_audio' + qwen2_vl = 'qwen2_vl' + qwen2_5_vl = 'qwen2_5_vl' + qwen2_5_omni = 'qwen2_5_omni' + qwen2_audio = 'qwen2_audio' + qvq = 'qvq' + qwen2_gme = 'qwen2_gme' + ovis1_6 = 'ovis1_6' + ovis1_6_llama3 = 'ovis1_6_llama3' + ovis2 = 'ovis2' + + glm4v = 'glm4v' + glm_edge_v = 'glm_edge_v' + cogvlm = 'cogvlm' + cogagent_vqa = 'cogagent_vqa' + cogagent_chat = 'cogagent_chat' + cogvlm2 = 'cogvlm2' + cogvlm2_video = 'cogvlm2_video' + + internvl = 'internvl' + internvl_phi3 = 'internvl_phi3' + internvl2 = 'internvl2' + internvl2_phi3 = 'internvl2_phi3' + internvl2_5 = 'internvl2_5' + internvl3 = 'internvl3' + xcomposer2 = 'xcomposer2' + xcomposer2_4khd = 'xcomposer2_4khd' + xcomposer2_5 = 'xcomposer2_5' + xcomposer2_5_ol_audio = 'xcomposer2_5_ol_audio' + + llama3_2_vision = 'llama3_2_vision' + llama4 = 'llama4' + llama3_1_omni = 'llama3_1_omni' + + llava1_5_hf = 'llava1_5_hf' + llava1_6_mistral_hf = 'llava1_6_mistral_hf' + llava1_6_vicuna_hf = 'llava1_6_vicuna_hf' + llava1_6_yi_hf = 'llava1_6_yi_hf' + llama3_llava_next_hf = 'llama3_llava_next_hf' + llava_next_qwen_hf = 'llava_next_qwen_hf' + llava_next_video_hf = 'llava_next_video_hf' + llava_next_video_yi_hf = 'llava_next_video_yi_hf' + llava_onevision_hf = 'llava_onevision_hf' + yi_vl = 'yi_vl' + + llava_llama3_1_hf = 'llava_llama3_1_hf' # DaozeZhang + llava_llama3_hf = 'llava_llama3_hf' # xtuner + + llava1_6_mistral = 'llava1_6_mistral' + llava1_6_yi = 'llava1_6_yi' + llava_next_qwen = 'llava_next_qwen' + llama3_llava_next = 'llama3_llava_next' + + deepseek_vl = 'deepseek_vl' + deepseek_vl2 = 'deepseek_vl2' + deepseek_janus = 'deepseek_janus' + deepseek_janus_pro = 'deepseek_janus_pro' + + minicpmv = 'minicpmv' + minicpmv2_5 = 'minicpmv2_5' + minicpmv2_6 = 'minicpmv2_6' + minicpmo2_6 = 'minicpmo2_6' + + minimax_vl = 'minimax_vl' + + mplug_owl2 = 'mplug_owl2' + mplug_owl2_1 = 'mplug_owl2_1' + mplug_owl3 = 'mplug_owl3' + mplug_owl3_241101 = 'mplug_owl3_241101' + doc_owl2 = 'doc_owl2' + + emu3_gen = 'emu3_gen' + emu3_chat = 'emu3_chat' + got_ocr2 = 'got_ocr2' + got_ocr2_hf = 'got_ocr2_hf' + step_audio = 'step_audio' + kimi_vl = 'kimi_vl' + + phi3_vision = 'phi3_vision' + phi4_multimodal = 'phi4_multimodal' + florence = 'florence' + idefics3 = 'idefics3' + paligemma = 'paligemma' + molmo = 'molmo' + molmoe = 'molmoe' + pixtral = 'pixtral' + megrez_omni = 'megrez_omni' + valley = 'valley' + gemma3_vision = 'gemma3_vision' + mistral_2503 = 'mistral_2503' + + +class ModelType(LLMModelType, MLLMModelType, BertModelType, RMModelType): + + @classmethod + def get_model_name_list(cls) -> List[str]: + + def _get_model_name_list(cls): + res = [] + for k in cls.__dict__: + if k.startswith('__'): + continue + value = getattr(cls, k) + if isinstance(value, str): + res.append(value) + return res + + return list( + chain.from_iterable( + _get_model_name_list(model_type_cls) + for model_type_cls in [LLMModelType, MLLMModelType, BertModelType, RMModelType])) diff --git a/ms-swift/swift/llm/model/model/__pycache__/internlm.cpython-310.pyc b/ms-swift/swift/llm/model/model/__pycache__/internlm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2d1c9ba50b5b61ee62caf95b9c8ff0c961b9cf1 Binary files /dev/null and b/ms-swift/swift/llm/model/model/__pycache__/internlm.cpython-310.pyc differ diff --git a/ms-swift/swift/llm/model/model/bert.py b/ms-swift/swift/llm/model/model/bert.py new file mode 100644 index 0000000000000000000000000000000000000000..3686788cfd443fd413c3ede7017ea04cbbfc24cb --- /dev/null +++ b/ms-swift/swift/llm/model/model/bert.py @@ -0,0 +1,66 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from types import MethodType + +import torch.nn.functional as F +from transformers import AutoConfig, AutoModel + +from swift.utils import get_logger +from ..constant import BertModelType +from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_from_local, register_model + +logger = get_logger() + + +def get_model_tokenizer_modern_bert(model_dir, *args, **kwargs): + model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True) + model_config.reference_compile = False + kwargs['model_config'] = model_config + return get_model_tokenizer_from_local(model_dir, *args, **kwargs) + + +register_model( + ModelMeta( + BertModelType.modern_bert, [ + ModelGroup([ + Model('answerdotai/ModernBERT-base', 'answerdotai/ModernBERT-base'), + Model('answerdotai/ModernBERT-large', 'answerdotai/ModernBERT-large'), + ]) + ], + None, + get_model_tokenizer_modern_bert, + requires=['transformers>=4.48'], + tags=['bert'])) + + +def get_model_tokenizer_gte_bert(*args, **kwargs): + kwargs['automodel_class'] = AutoModel + model, tokenizer = get_model_tokenizer_from_local(*args, **kwargs) + if model is not None: + + def _normalizer_hook(module, input, output): + output.last_hidden_state = F.normalize(output.last_hidden_state[:, 0], p=2, dim=1) + return output + + model.register_forward_hook(_normalizer_hook) + return model, tokenizer + + +register_model( + ModelMeta( + BertModelType.modern_bert_gte, + [ModelGroup([ + Model('iic/gte-modernbert-base', 'Alibaba-NLP/gte-modernbert-base'), + ])], + None, + get_model_tokenizer_gte_bert, + requires=['transformers>=4.48'], + tags=['bert', 'embedding'])) + +register_model( + ModelMeta( + BertModelType.bert, [ModelGroup([ + Model('iic/nlp_structbert_backbone_base_std'), + ])], + None, + get_model_tokenizer_from_local, + tags=['bert'])) diff --git a/ms-swift/swift/llm/model/model/deepseek.py b/ms-swift/swift/llm/model/model/deepseek.py new file mode 100644 index 0000000000000000000000000000000000000000..95312a1c65a9fd405aa07ebc4327511feabc0e89 --- /dev/null +++ b/ms-swift/swift/llm/model/model/deepseek.py @@ -0,0 +1,282 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import sys +from typing import Any, Dict + +from swift.llm import TemplateType +from ..constant import LLMModelType, MLLMModelType +from ..model_arch import ModelArch +from ..patcher import patch_output_clone, patch_output_to_input_device +from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model +from ..utils import ModelInfo, git_clone_github, use_submodel_func + +register_model( + ModelMeta( + LLMModelType.deepseek, [ + ModelGroup([ + Model('deepseek-ai/deepseek-llm-7b-base', 'deepseek-ai/deepseek-llm-7b-base'), + Model('deepseek-ai/deepseek-llm-7b-chat', 'deepseek-ai/deepseek-llm-7b-chat'), + Model('deepseek-ai/deepseek-llm-67b-base', 'deepseek-ai/deepseek-llm-67b-base'), + Model('deepseek-ai/deepseek-llm-67b-chat', 'deepseek-ai/deepseek-llm-67b-chat'), + ]), + ModelGroup( + [ + Model('deepseek-ai/deepseek-math-7b-base', 'deepseek-ai/deepseek-math-7b-base'), + Model('deepseek-ai/deepseek-math-7b-instruct', 'deepseek-ai/deepseek-math-7b-instruct'), + Model('deepseek-ai/deepseek-math-7b-rl', 'deepseek-ai/deepseek-math-7b-rl'), + ], + tags=['math'], + ), + ModelGroup( + [ + Model('deepseek-ai/deepseek-coder-1.3b-base', 'deepseek-ai/deepseek-coder-1.3b-base'), + Model('deepseek-ai/deepseek-coder-1.3b-instruct', 'deepseek-ai/deepseek-coder-1.3b-instruct'), + Model('deepseek-ai/deepseek-coder-6.7b-base', 'deepseek-ai/deepseek-coder-6.7b-base'), + Model('deepseek-ai/deepseek-coder-6.7b-instruct', 'deepseek-ai/deepseek-coder-6.7b-instruct'), + Model('deepseek-ai/deepseek-coder-33b-base', 'deepseek-ai/deepseek-coder-33b-base'), + Model('deepseek-ai/deepseek-coder-33b-instruct', 'deepseek-ai/deepseek-coder-33b-instruct'), + ], + tags=['coding'], + ), + ], + TemplateType.deepseek, + get_model_tokenizer_with_flash_attn, + architectures=['LlamaForCausalLM'], + model_arch=ModelArch.llama)) + + +def get_model_tokenizer_deepseek_moe(model_dir: str, + model_info: ModelInfo, + model_kwargs: Dict[str, Any], + load_model: bool = True, + **kwargs): + model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs) + if model is not None: + # fix dtype bug + mlp_cls = model.model.layers[1].mlp.__class__ + + for module in model.modules(): + if isinstance(module, mlp_cls): + patch_output_to_input_device(module) + return model, tokenizer + + +register_model( + ModelMeta( + LLMModelType.deepseek_moe, + [ + ModelGroup([ + Model('deepseek-ai/deepseek-moe-16b-chat', 'deepseek-ai/deepseek-moe-16b-chat'), + Model('deepseek-ai/deepseek-moe-16b-base', 'deepseek-ai/deepseek-moe-16b-base'), + ], ), + ], + TemplateType.deepseek, + get_model_tokenizer_deepseek_moe, + architectures=['DeepseekForCausalLM'], + )) + +register_model( + ModelMeta( + LLMModelType.deepseek_v2, + [ + ModelGroup([ + Model('deepseek-ai/DeepSeek-Coder-V2-Instruct', 'deepseek-ai/DeepSeek-Coder-V2-Instruct'), + Model('deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct', 'deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct'), + Model('deepseek-ai/DeepSeek-Coder-V2-Base', 'deepseek-ai/DeepSeek-Coder-V2-Base'), + Model('deepseek-ai/DeepSeek-Coder-V2-Lite-Base', 'deepseek-ai/DeepSeek-Coder-V2-Lite-Base'), + Model('deepseek-ai/DeepSeek-V2-Lite', 'deepseek-ai/DeepSeek-V2-Lite'), + Model('deepseek-ai/DeepSeek-V2-Lite-Chat', 'deepseek-ai/DeepSeek-V2-Lite-Chat'), + Model('deepseek-ai/DeepSeek-V2', 'deepseek-ai/DeepSeek-V2'), + Model('deepseek-ai/DeepSeek-V2-Chat', 'deepseek-ai/DeepSeek-V2-Chat'), + ]), + ], + TemplateType.deepseek, + get_model_tokenizer_deepseek_moe, + architectures=['DeepseekV2ForCausalLM'], + model_arch=ModelArch.deepseek_v2, + requires=['transformers>=4.39.3'], + )) + +register_model( + ModelMeta( + LLMModelType.deepseek_v2_5, + [ + ModelGroup([ + Model('deepseek-ai/DeepSeek-V2.5', 'deepseek-ai/DeepSeek-V2.5'), + Model('deepseek-ai/DeepSeek-V2.5-1210', 'deepseek-ai/DeepSeek-V2.5-1210'), + Model('deepseek-ai/DeepSeek-V3-Base', 'deepseek-ai/DeepSeek-V3-Base'), + Model('deepseek-ai/DeepSeek-V3', 'deepseek-ai/DeepSeek-V3'), + Model('deepseek-ai/DeepSeek-V3-0324', 'deepseek-ai/DeepSeek-V3-0324'), + ]), + ModelGroup([ + Model('cognitivecomputations/DeepSeek-V3-awq', 'cognitivecomputations/DeepSeek-V3-AWQ'), + Model('cognitivecomputations/DeepSeek-V3-0324-AWQ', 'cognitivecomputations/DeepSeek-V3-0324-AWQ') + ]) + ], + TemplateType.deepseek_v2_5, + get_model_tokenizer_deepseek_moe, + architectures=['DeepseekV2ForCausalLM', 'DeepseekV3ForCausalLM'], + model_arch=ModelArch.deepseek_v2, + requires=['transformers>=4.39.3'], + )) + + +def _get_deepseek_vl(processor, llm_prefix, model_dir, *args, **kwargs): + kwargs['tokenizer'] = processor.tokenizer + model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, *args, **kwargs) + if model: + llm = getattr(model, llm_prefix) + patch_output_clone(llm.model.embed_tokens) + patch_output_to_input_device(llm.model.embed_tokens) + use_submodel_func(model, llm_prefix) + model.generation_config = llm.generation_config + return model, processor + + +def get_model_tokenizer_deepseek_vl(model_dir: str, *args, **kwargs): + # compat with python==3.10 + if sys.version_info.minor >= 10: + import collections + import collections.abc + for type_name in collections.abc.__all__: + setattr(collections, type_name, getattr(collections.abc, type_name)) + local_repo_path = kwargs.get('local_repo_path') + if not local_repo_path: + local_repo_path = git_clone_github('https://github.com/deepseek-ai/DeepSeek-VL') + sys.path.append(local_repo_path) + from deepseek_vl.models import VLChatProcessor + processor = VLChatProcessor.from_pretrained(model_dir) + return _get_deepseek_vl(processor, 'language_model', model_dir, *args, **kwargs) + + +register_model( + ModelMeta( + MLLMModelType.deepseek_vl, + [ + ModelGroup([ + Model('deepseek-ai/deepseek-vl-1.3b-chat', 'deepseek-ai/deepseek-vl-1.3b-chat'), + Model('deepseek-ai/deepseek-vl-7b-chat', 'deepseek-ai/deepseek-vl-7b-chat'), + ], ), + ], + TemplateType.deepseek_vl, + get_model_tokenizer_deepseek_vl, + architectures=['MultiModalityCausalLM'], + model_arch=ModelArch.deepseek_vl, + tags=['vision'], + )) + + +def get_model_tokenizer_deepseek_janus(model_dir: str, *args, **kwargs): + local_repo_path = kwargs.get('local_repo_path') + if not local_repo_path: + local_repo_path = git_clone_github('https://github.com/deepseek-ai/Janus') + sys.path.append(local_repo_path) + from janus.models import VLChatProcessor + + processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_dir) + return _get_deepseek_vl(processor, 'language_model', model_dir, *args, **kwargs) + + +register_model( + ModelMeta( + MLLMModelType.deepseek_janus, + [ + ModelGroup([ + Model('deepseek-ai/Janus-1.3B', 'deepseek-ai/Janus-1.3B'), + ]), + ], + TemplateType.deepseek_janus, + get_model_tokenizer_deepseek_janus, + model_arch=ModelArch.deepseek_janus, + tags=['vision'], + )) + +register_model( + ModelMeta( + MLLMModelType.deepseek_janus_pro, + [ + ModelGroup([ + Model('deepseek-ai/Janus-Pro-1B', 'deepseek-ai/Janus-Pro-1B'), + Model('deepseek-ai/Janus-Pro-7B', 'deepseek-ai/Janus-Pro-7B'), + ]), + ], + TemplateType.deepseek_janus_pro, + get_model_tokenizer_deepseek_janus, + model_arch=ModelArch.deepseek_janus, + tags=['vision'], + )) + + +def get_model_tokenizer_deepseek_vl2(model_dir: str, *args, **kwargs): + local_repo_path = kwargs.get('local_repo_path') + if not local_repo_path: + local_repo_path = git_clone_github('https://github.com/deepseek-ai/DeepSeek-VL2') + sys.path.append(local_repo_path) + try: + from deepseek_vl2.models import DeepseekVLV2Processor + except ImportError: + # compat transformers>=4.42 + import transformers + transformers.models.llama.modeling_llama.LlamaFlashAttention2 = None + from deepseek_vl2.models import DeepseekVLV2Processor + processor: DeepseekVLV2Processor = DeepseekVLV2Processor.from_pretrained(model_dir) + return _get_deepseek_vl(processor, 'language', model_dir, *args, **kwargs) + + +register_model( + ModelMeta( + MLLMModelType.deepseek_vl2, + [ + ModelGroup([ + Model('deepseek-ai/deepseek-vl2-tiny', 'deepseek-ai/deepseek-vl2-tiny'), + Model('deepseek-ai/deepseek-vl2-small', 'deepseek-ai/deepseek-vl2-small'), + Model('deepseek-ai/deepseek-vl2', 'deepseek-ai/deepseek-vl2'), + ]), + ], + TemplateType.deepseek_vl2, + get_model_tokenizer_deepseek_vl2, + model_arch=ModelArch.deepseek_vl2, + requires=['transformers<4.42'], + architectures=['DeepseekV2ForCausalLM'], + tags=['vision'], + )) + +register_model( + ModelMeta( + LLMModelType.deepseek_r1, + [ + ModelGroup([ + Model('deepseek-ai/DeepSeek-R1', 'deepseek-ai/DeepSeek-R1'), + Model('deepseek-ai/DeepSeek-R1-Zero', 'deepseek-ai/DeepSeek-R1-Zero'), + ]), + ModelGroup([ + Model('cognitivecomputations/DeepSeek-R1-awq', 'cognitivecomputations/DeepSeek-R1-AWQ'), + ]) + ], + TemplateType.deepseek_r1, + get_model_tokenizer_deepseek_moe, + architectures=['DeepseekV3ForCausalLM'], + model_arch=ModelArch.deepseek_v2, + requires=['transformers>=4.39.3'], + )) + +register_model( + ModelMeta( + LLMModelType.deepseek_r1_distill, + [ + ModelGroup([ + Model('deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B', 'deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B'), + Model('deepseek-ai/DeepSeek-R1-Distill-Qwen-7B', 'deepseek-ai/DeepSeek-R1-Distill-Qwen-7B'), + Model('deepseek-ai/DeepSeek-R1-Distill-Qwen-14B', 'deepseek-ai/DeepSeek-R1-Distill-Qwen-14B'), + Model('deepseek-ai/DeepSeek-R1-Distill-Qwen-32B', 'deepseek-ai/DeepSeek-R1-Distill-Qwen-32B'), + ], + requires=['transformers>=4.37']), + ModelGroup([ + Model('deepseek-ai/DeepSeek-R1-Distill-Llama-8B', 'deepseek-ai/DeepSeek-R1-Distill-Llama-8B'), + Model('deepseek-ai/DeepSeek-R1-Distill-Llama-70B', 'deepseek-ai/DeepSeek-R1-Distill-Llama-70B'), + ]), + ], + TemplateType.deepseek_r1, + get_model_tokenizer_with_flash_attn, + architectures=['Qwen2ForCausalLM', 'LlamaForCausalLM'], + model_arch=ModelArch.llama, + )) diff --git a/ms-swift/swift/llm/model/model/llava.py b/ms-swift/swift/llm/model/model/llava.py new file mode 100644 index 0000000000000000000000000000000000000000..2b698e8897eb28b01fbd77e3f4328d538c00038c --- /dev/null +++ b/ms-swift/swift/llm/model/model/llava.py @@ -0,0 +1,391 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import sys +from functools import partial, wraps +from typing import Any, Dict + +from transformers import AutoConfig + +from swift.llm import TemplateType +from ..constant import MLLMModelType +from ..model_arch import ModelArch +from ..register import (Model, ModelGroup, ModelMeta, get_model_tokenizer_multimodal, + get_model_tokenizer_with_flash_attn, register_model) +from ..utils import ModelInfo, git_clone_github, safe_snapshot_download + + +def get_model_tokenizer_llava_llama(model_dir: str, + model_info: ModelInfo, + model_kwargs: Dict[str, Any], + load_model: bool = True, + **kwargs): + from transformers import LlavaForConditionalGeneration, LlavaConfig + + kwargs['model_config'] = LlavaConfig.from_pretrained(model_dir) + kwargs['automodel_class'] = kwargs['automodel_class'] or LlavaForConditionalGeneration + model, processor = get_model_tokenizer_multimodal(model_dir, model_info, model_kwargs, load_model, **kwargs) + return model, processor + + +register_model( + ModelMeta( + MLLMModelType.llava_llama3_hf, + [ + ModelGroup([ + Model('AI-ModelScope/llava-llama-3-8b-v1_1-transformers', 'xtuner/llava-llama-3-8b-v1_1-transformers'), + ]), + ], + TemplateType.llava_llama3_hf, + get_model_tokenizer_llava_llama, + architectures=['LlavaForConditionalGeneration'], + model_arch=ModelArch.llava_hf, + requires=['transformers>=4.36'], + tags=['vision'], + )) + + +def _patch_llava(model): + if hasattr(model, '__old_generate'): + return + generate = model.generate + model.__old_generate = generate + + @wraps(generate) + def _new_generate(inputs=None, *args, **kwargs): + input_ids = kwargs.pop('input_ids', None) + if inputs is None and input_ids is not None: + inputs = input_ids + return generate(inputs, *args, **kwargs) + + model.generate = _new_generate + + +def get_model_tokenizer_llava_hf(model_dir: str, *args, **kwargs): + from transformers import LlavaForConditionalGeneration + kwargs['automodel_class'] = kwargs['automodel_class'] or LlavaForConditionalGeneration + model, processor = get_model_tokenizer_multimodal(model_dir, *args, **kwargs) + return model, processor + + +register_model( + ModelMeta( + MLLMModelType.llava1_5_hf, + [ + ModelGroup([ + Model('llava-hf/llava-1.5-7b-hf', 'llava-hf/llava-1.5-7b-hf'), + Model('llava-hf/llava-1.5-13b-hf', 'llava-hf/llava-1.5-13b-hf'), + ]), + ], + TemplateType.llava1_5_hf, + get_model_tokenizer_llava_hf, + architectures=['LlavaForConditionalGeneration'], + model_arch=ModelArch.llava_hf, + requires=['transformers>=4.36'], + tags=['vision'], + )) + + +def get_model_tokenizer_llava_onevision(*args, **kwargs): + from transformers import LlavaOnevisionForConditionalGeneration + kwargs['automodel_class'] = kwargs['automodel_class'] or LlavaOnevisionForConditionalGeneration + return get_model_tokenizer_llava_hf(*args, **kwargs) + + +register_model( + ModelMeta( + MLLMModelType.llava_onevision_hf, + [ + ModelGroup([ + Model('llava-hf/llava-onevision-qwen2-0.5b-ov-hf', 'llava-hf/llava-onevision-qwen2-0.5b-ov-hf'), + Model('llava-hf/llava-onevision-qwen2-7b-ov-hf', 'llava-hf/llava-onevision-qwen2-7b-ov-hf'), + Model('llava-hf/llava-onevision-qwen2-72b-ov-hf', 'llava-hf/llava-onevision-qwen2-72b-ov-hf'), + ], ), + ], + TemplateType.llava_onevision_hf, + get_model_tokenizer_llava_onevision, + architectures=['LlavaOnevisionForConditionalGeneration'], + model_arch=ModelArch.llava_hf, + requires=['transformers>=4.45'], + tags=['vision', 'video'], + )) + +register_model( + ModelMeta( + MLLMModelType.llava_next_qwen_hf, + [ + ModelGroup([ + Model('llava-hf/llava-next-72b-hf', 'llava-hf/llava-next-72b-hf'), + Model('llava-hf/llava-next-110b-hf', 'llava-hf/llava-next-110b-hf'), + ], ), + ], + TemplateType.llava_next_qwen_hf, + get_model_tokenizer_llava_onevision, + architectures=['LlavaNextForConditionalGeneration'], + model_arch=ModelArch.llava_hf, + requires=['transformers>=4.39'], + tags=['vision'], + )) + + +def get_model_tokenizer_llava_next(*args, **kwargs): + from transformers import LlavaNextForConditionalGeneration + kwargs['automodel_class'] = kwargs['automodel_class'] or LlavaNextForConditionalGeneration + return get_model_tokenizer_llava_hf(*args, **kwargs) + + +register_model( + ModelMeta( + MLLMModelType.llama3_llava_next_hf, + [ + ModelGroup([ + Model('llava-hf/llama3-llava-next-8b-hf', 'llava-hf/llama3-llava-next-8b-hf'), + ], ), + ], + TemplateType.llama3_llava_next_hf, + get_model_tokenizer_llava_next, + architectures=['LlavaNextForConditionalGeneration'], + model_arch=ModelArch.llava_hf, + requires=['transformers>=4.39'], + tags=['vision'], + )) + +register_model( + ModelMeta( + MLLMModelType.llava1_6_vicuna_hf, + [ + ModelGroup([ + Model('llava-hf/llava-v1.6-vicuna-7b-hf', 'llava-hf/llava-v1.6-vicuna-7b-hf'), + Model('llava-hf/llava-v1.6-vicuna-13b-hf', 'llava-hf/llava-v1.6-vicuna-13b-hf'), + ], ), + ], + TemplateType.llava1_6_vicuna_hf, + get_model_tokenizer_llava_next, + architectures=['LlavaNextForConditionalGeneration'], + model_arch=ModelArch.llava_hf, + requires=['transformers>=4.39'], + tags=['vision'], + )) + +register_model( + ModelMeta( + MLLMModelType.llava1_6_mistral_hf, + [ + ModelGroup([ + Model('llava-hf/llava-v1.6-mistral-7b-hf', 'llava-hf/llava-v1.6-mistral-7b-hf'), + ], ), + ], + TemplateType.llava1_6_mistral_hf, + get_model_tokenizer_llava_next, + architectures=['LlavaNextForConditionalGeneration'], + model_arch=ModelArch.llava_hf, + requires=['transformers>=4.39'], + tags=['vision'], + )) + +register_model( + ModelMeta( + MLLMModelType.llava_llama3_1_hf, + [ + ModelGroup([ + Model('swift/llava-llama3.1-8b'), + ], ), + ], + TemplateType.llava_llama3_1_hf, + get_model_tokenizer_llava_next, + architectures=['LlavaNextForConditionalGeneration'], + model_arch=ModelArch.llava_hf, + requires=['transformers>=4.41'], + tags=['vision'], + )) + + +def get_model_tokenizer_llava_next_yi(*args, **kwargs): + model, tokenizer = get_model_tokenizer_llava_next(*args, **kwargs) + if model is not None: + model.config.image_token_index = 64003 + return model, tokenizer + + +register_model( + ModelMeta( + MLLMModelType.llava1_6_yi_hf, + [ + ModelGroup([ + Model('llava-hf/llava-v1.6-34b-hf', 'llava-hf/llava-v1.6-34b-hf'), + ], ), + ], + TemplateType.llava1_6_yi_hf, + get_model_tokenizer_llava_next_yi, + architectures=['LlavaNextForConditionalGeneration'], + model_arch=ModelArch.llava_hf, + requires=['transformers>=4.39'], + tags=['vision'], + )) + + +def get_model_tokenizer_llava_next_video(*args, **kwargs): + from transformers import LlavaNextVideoForConditionalGeneration + kwargs['automodel_class'] = kwargs['automodel_class'] or LlavaNextVideoForConditionalGeneration + return get_model_tokenizer_llava_hf(*args, **kwargs) + + +register_model( + ModelMeta( + MLLMModelType.llava_next_video_hf, + [ + ModelGroup([ + Model('llava-hf/LLaVA-NeXT-Video-7B-DPO-hf', 'llava-hf/LLaVA-NeXT-Video-7B-DPO-hf'), + Model('llava-hf/LLaVA-NeXT-Video-7B-32K-hf', 'llava-hf/LLaVA-NeXT-Video-7B-32K-hf'), + Model('llava-hf/LLaVA-NeXT-Video-7B-hf', 'llava-hf/LLaVA-NeXT-Video-7B-hf'), + ], ), + ], + TemplateType.llava_next_video_hf, + get_model_tokenizer_llava_next_video, + architectures=['LlavaNextVideoForConditionalGeneration'], + model_arch=ModelArch.llava_next_video_hf, + requires=['transformers>=4.42', 'av'], + tags=['video'], + )) + + +def get_model_tokenizer_llava_next_video_yi(*args, **kwargs): + model, tokenizer = get_model_tokenizer_llava_next_video(*args, **kwargs) + if model is not None: + model.config.video_token_index = 64003 + model.config.image_token_index = 64004 + return model, tokenizer + + +register_model( + ModelMeta( + MLLMModelType.llava_next_video_yi_hf, + [ + ModelGroup([ + Model('llava-hf/LLaVA-NeXT-Video-34B-hf', 'llava-hf/LLaVA-NeXT-Video-34B-hf'), + ], ), + ], + TemplateType.llava_next_video_hf, + get_model_tokenizer_llava_next_video_yi, + architectures=['LlavaNextVideoForConditionalGeneration'], + model_arch=ModelArch.llava_next_video_hf, + requires=['transformers>=4.42', 'av'], + tags=['video'], + )) + + +def get_model_tokenizer_llava(model_dir: str, + model_info: ModelInfo, + model_kwargs: Dict[str, Any], + load_model: bool = True, + **kwargs): + llm_model_type = kwargs.pop('llm_model_type') + local_repo_path = kwargs.get('local_repo_path') + if not local_repo_path: + if 'next' in llm_model_type: + repo_path = 'https://github.com/LLaVA-VL/LLaVA-NeXT' + else: + repo_path = 'https://github.com/haotian-liu/LLaVA' + local_repo_path = git_clone_github(repo_path) + sys.path.append(local_repo_path) + + if llm_model_type == 'mistral': + from llava.model import LlavaMistralForCausalLM, LlavaMistralConfig + model_config = LlavaMistralConfig.from_pretrained(model_dir) + automodel_class = LlavaMistralForCausalLM + elif 'llama' in llm_model_type: # llama + from llava.model import LlavaLlamaForCausalLM, LlavaConfig + if not hasattr(LlavaLlamaForCausalLM, '__old_forward'): # Avoid double patching + forward = LlavaLlamaForCausalLM.forward + LlavaLlamaForCausalLM.__old_forward = forward + + @wraps(forward) + def _new_forward(*args, **kwargs): + kwargs.pop('cache_position', None) + return forward(*args, **kwargs) + + LlavaLlamaForCausalLM.forward = _new_forward + model_config = LlavaConfig.from_pretrained(model_dir) + automodel_class = LlavaLlamaForCausalLM + else: # qwen + from llava.model import LlavaQwenForCausalLM + automodel_class = LlavaQwenForCausalLM + model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True) + + model_config.mm_vision_tower = safe_snapshot_download('AI-ModelScope/clip-vit-large-patch14-336', check_local=True) + kwargs['model_config'] = model_config + kwargs['automodel_class'] = automodel_class + model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs) + + if model is not None: + model.resize_token_embeddings(len(tokenizer)) + vision_tower = model.get_vision_tower() + device_map = str(model_kwargs.get('device_map', str(model.device))) + if not vision_tower.is_loaded: + vision_tower.load_model(device_map=device_map) + if not hasattr(model.config, 'max_sequence_length'): + model.config.max_sequence_length = 2048 + _patch_llava(model) + tokenizer.image_processor = vision_tower.image_processor + return model, tokenizer + + +register_model( + ModelMeta( + MLLMModelType.llama3_llava_next, + [ + ModelGroup([ + Model('AI-ModelScope/llama3-llava-next-8b', 'lmms-lab/llama3-llava-next-8b'), + ], ), + ], + TemplateType.llama3_llava_next, + partial(get_model_tokenizer_llava, llm_model_type='next_llama'), + architectures=['LlavaLlamaForCausalLM'], + model_arch=ModelArch.llava_llama, + requires=['transformers>=4.42', 'av'], + tags=['vision'], + )) + +register_model( + ModelMeta( + MLLMModelType.llava1_6_mistral, + [ + ModelGroup([ + Model('AI-ModelScope/llava-v1.6-mistral-7b', 'liuhaotian/llava-v1.6-mistral-7b'), + ], ), + ], + TemplateType.llava1_6_mistral, + partial(get_model_tokenizer_llava, llm_model_type='mistral'), + requires=['transformers>=4.34'], + architectures=['LlavaMistralForCausalLM'], + model_arch=ModelArch.llava_mistral, + tags=['vision'], + )) + +register_model( + ModelMeta( + MLLMModelType.llava1_6_yi, [ + ModelGroup([ + Model('AI-ModelScope/llava-v1.6-34b', 'liuhaotian/llava-v1.6-34b'), + ], ), + ], + TemplateType.llava1_6_yi, + partial(get_model_tokenizer_llava, llm_model_type='llama'), + requires=['transformers>=4.34'], + architectures=['LlavaLlamaForCausalLM'], + tags=['vision'], + model_arch=None)) + +register_model( + ModelMeta( + MLLMModelType.llava_next_qwen, [ + ModelGroup([ + Model('AI-ModelScope/llava-next-72b', 'lmms-lab/llava-next-72b'), + Model('AI-ModelScope/llava-next-110b', 'lmms-lab/llava-next-110b'), + ], ), + ], + TemplateType.llava_next_qwen, + partial(get_model_tokenizer_llava, llm_model_type='next_qwen'), + architectures=['LlavaQwenForCausalLM'], + requires=['transformers>=4.42', 'av'], + tags=['vision'], + model_arch=None)) diff --git a/ms-swift/swift/llm/model/model/microsoft.py b/ms-swift/swift/llm/model/model/microsoft.py new file mode 100644 index 0000000000000000000000000000000000000000..c682f6efa9080b4f7c5604987cc288650f200d98 --- /dev/null +++ b/ms-swift/swift/llm/model/model/microsoft.py @@ -0,0 +1,234 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from functools import partial +from types import MethodType +from typing import Any, Dict + +from transformers import AutoConfig + +from swift.llm import TemplateType +from swift.utils import get_device, get_env_args +from ..constant import LLMModelType, MLLMModelType +from ..model_arch import ModelArch +from ..patcher import patch_ignore_check_imports, patch_output_clone +from ..register import (Model, ModelGroup, ModelMeta, get_model_tokenizer_multimodal, + get_model_tokenizer_with_flash_attn, register_model) +from ..utils import ModelInfo, use_submodel_func + + +def get_model_tokenizer_phi3_vision(model_dir: str, + model_info: ModelInfo, + model_kwargs: Dict[str, Any], + load_model: bool = True, + **kwargs): + processor_kwargs = {} + if 'num_crops' in kwargs: + processor_kwargs['num_crops'] = get_env_args('num_crops', int, kwargs['num_crops']) + from transformers import AutoProcessor + processor = AutoProcessor.from_pretrained(model_dir, trust_remote_code=True, **processor_kwargs) + model, tokenizer = get_model_tokenizer_with_flash_attn( + model_dir, model_info, model_kwargs, load_model, tokenizer=processor.tokenizer, **kwargs) + + if load_model: + patch_output_clone(model.model.vision_embed_tokens.wte) + + return model, processor + + +register_model( + ModelMeta( + MLLMModelType.phi3_vision, + [ + ModelGroup([ + Model('LLM-Research/Phi-3-vision-128k-instruct', 'microsoft/Phi-3-vision-128k-instruct'), + Model('LLM-Research/Phi-3.5-vision-instruct', 'microsoft/Phi-3.5-vision-instruct'), + ]) + ], + TemplateType.phi3_vision, + partial(get_model_tokenizer_phi3_vision, num_crops=4), + architectures=['Phi3VForCausalLM'], + model_arch=ModelArch.phi3_vision, + requires=['transformers>=4.36'], + tags=['vision'], + )) + + +def get_model_tokenizer_phi4_multimodal(*args, **kwargs): + model, processor = get_model_tokenizer_multimodal(*args, **kwargs) + processor.audio_processor.audio_compression_rate = processor.audio_processor.compression_rate + processor.audio_processor.audio_downsample_rate = processor.audio_processor.qformer_compression_rate + processor.audio_processor.audio_feat_stride = processor.audio_processor.feat_stride + del processor.audio_processor.feature_size + del processor.audio_processor.sampling_rate + del processor.audio_processor.padding_value + del processor.__class__.chat_template + processor.chat_template = None + if model is not None: + model.set_lora_adapter(['vision', 'speech']) + return model, processor + + +register_model( + ModelMeta( + MLLMModelType.phi4_multimodal, + [ModelGroup([ + Model('LLM-Research/Phi-4-multimodal-instruct', 'microsoft/Phi-4-multimodal-instruct'), + ])], + TemplateType.phi4_multimodal, + get_model_tokenizer_phi4_multimodal, + architectures=['Phi4MMForCausalLM'], + model_arch=ModelArch.phi4_multimodal, + requires=['transformers>=4.36,<4.49', 'backoff', 'soundfile'], + tags=['vision', 'audio'], + )) + + +def get_model_tokenizer_florence(model_dir: str, + model_info: ModelInfo, + model_kwargs: Dict[str, Any], + load_model: bool = True, + **kwargs): + model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True) + model_config.vision_config.model_type = 'davit' # fix merge-lora + if model_kwargs['device_map'] == 'auto': + model_kwargs['device_map'] = get_device() + kwargs['model_config'] = model_config + with patch_ignore_check_imports(): + model, processor = get_model_tokenizer_multimodal(model_dir, model_info, model_kwargs, load_model, **kwargs) + + if model is not None: + model.vision_tower.enable_checkpoint = True + use_submodel_func(model, 'language_model', ['generate', 'forward']) + return model, processor + + +register_model( + ModelMeta( + MLLMModelType.florence, + [ + # llama2 + ModelGroup([ + Model('AI-ModelScope/Florence-2-base-ft', 'microsoft/Florence-2-base-ft'), + Model('AI-ModelScope/Florence-2-base', 'microsoft/Florence-2-base'), + Model('AI-ModelScope/Florence-2-large', 'microsoft/Florence-2-large'), + Model('AI-ModelScope/Florence-2-large-ft', 'microsoft/Florence-2-large-ft'), + ]), + ], + TemplateType.florence, + get_model_tokenizer_florence, + architectures=['Florence2ForConditionalGeneration'], + model_arch=ModelArch.florence, + tags=['vision'], + )) + + +def get_model_tokenizer_phi3_small(model_dir: str, + model_info: ModelInfo, + model_kwargs: Dict[str, Any], + load_model: bool = True, + **kwargs): + model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs) + + def rotary_emb(self, query_states, key_states, **kwargs): + q_type = query_states.dtype + k_type = key_states.dtype + query_states, key_states = self.rotory_emb_origin(query_states, key_states, **kwargs) + query_states = query_states.to(q_type) + key_states = key_states.to(k_type) + return query_states, key_states + + if model is not None: + for i in range(32): + re = model.model.layers[i].self_attn.rotary_emb + re.rotory_emb_origin = re.forward + re.forward = MethodType(rotary_emb, re) + return model, tokenizer + + +register_model( + ModelMeta( + LLMModelType.phi3_small, + [ + ModelGroup([ + Model('LLM-Research/Phi-3-small-8k-instruct', 'microsoft/Phi-3-small-8k-instruct'), + Model('LLM-Research/Phi-3-small-128k-instruct', 'microsoft/Phi-3-small-128k-instruct'), + ]), + ], + TemplateType.phi3, + get_model_tokenizer_phi3_small, + architectures=['Phi3SmallForCausalLM'], + model_arch=ModelArch.phi3_small, + requires=['transformers>=4.36'], + )) + + +def get_model_tokenizer_phi(model_dir: str, + model_info: ModelInfo, + model_kwargs: Dict[str, Any], + load_model: bool = True, + **kwargs): + # TODO: check + return get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs) + + +register_model( + ModelMeta( + LLMModelType.phi2, + [ + ModelGroup([ + Model('AI-ModelScope/phi-2', 'microsoft/phi-2'), + ]), + ], + TemplateType.default, + get_model_tokenizer_phi, + architectures=['PhiForCausalLM'], + model_arch=ModelArch.phi2, + )) + +register_model( + ModelMeta( + LLMModelType.phi3, + [ + ModelGroup([ + Model('LLM-Research/Phi-3-mini-4k-instruct', 'microsoft/Phi-3-mini-4k-instruct'), + Model('LLM-Research/Phi-3-mini-128k-instruct', 'microsoft/Phi-3-mini-128k-instruct'), + Model('LLM-Research/Phi-3-medium-4k-instruct', 'microsoft/Phi-3-medium-4k-instruct'), + Model('LLM-Research/Phi-3-medium-128k-instruct', 'microsoft/Phi-3-medium-128k-instruct'), + Model('LLM-Research/Phi-3.5-mini-instruct', 'microsoft/Phi-3.5-mini-instruct'), + ]), + ModelGroup(Model('LLM-Research/Phi-4-mini-instruct', 'microsoft/Phi-4-mini-instruct')) + ], + TemplateType.phi3, + get_model_tokenizer_with_flash_attn, + architectures=['Phi3ForCausalLM'], + requires=['transformers>=4.36'], + model_arch=ModelArch.phi3, + )) + +register_model( + ModelMeta( + LLMModelType.phi4, + [ + ModelGroup([ + Model('LLM-Research/phi-4', 'microsoft/phi-4'), + ]), + ], + TemplateType.phi4, + get_model_tokenizer_with_flash_attn, + architectures=['Phi3ForCausalLM'], + requires=['transformers>=4.36'], + model_arch=ModelArch.phi3, + )) + +register_model( + ModelMeta( + LLMModelType.phi3_moe, + [ + ModelGroup([ + Model('LLM-Research/Phi-3.5-MoE-instruct', 'microsoft/Phi-3.5-MoE-instruct'), + ]), + ], + TemplateType.phi3, + get_model_tokenizer_with_flash_attn, + architectures=['PhiMoEForCausalLM'], + requires=['transformers>=4.36'], + )) diff --git a/ms-swift/swift/llm/model/model/mistral.py b/ms-swift/swift/llm/model/model/mistral.py new file mode 100644 index 0000000000000000000000000000000000000000..e13f71b5d8af8a3e02e7a127c08a168391cfade9 --- /dev/null +++ b/ms-swift/swift/llm/model/model/mistral.py @@ -0,0 +1,157 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict + +from swift.llm import TemplateType +from ..constant import LLMModelType, MLLMModelType +from ..model_arch import ModelArch +from ..register import (Model, ModelGroup, ModelMeta, get_model_tokenizer_multimodal, + get_model_tokenizer_with_flash_attn, register_model) +from ..utils import ModelInfo + +register_model( + ModelMeta( + LLMModelType.mistral, + [ + ModelGroup([ + Model('AI-ModelScope/Mistral-7B-Instruct-v0.1', 'mistralai/Mistral-7B-Instruct-v0.1'), + Model('AI-ModelScope/Mistral-7B-Instruct-v0.2', 'mistralai/Mistral-7B-Instruct-v0.2'), + Model('LLM-Research/Mistral-7B-Instruct-v0.3', 'mistralai/Mistral-7B-Instruct-v0.3'), + Model('AI-ModelScope/Mistral-7B-v0.1', 'mistralai/Mistral-7B-v0.1'), + Model('AI-ModelScope/Mistral-7B-v0.2-hf', 'alpindale/Mistral-7B-v0.2-hf'), + ]), + ModelGroup([ + Model('swift/Codestral-22B-v0.1', 'mistralai/Codestral-22B-v0.1'), + ]), + ], + TemplateType.llama, + get_model_tokenizer_with_flash_attn, + architectures=['MistralForCausalLM'], + model_arch=ModelArch.llama, + requires=['transformers>=4.34'], + )) + +register_model( + ModelMeta( + LLMModelType.mixtral, [ + ModelGroup([ + Model('AI-ModelScope/Mixtral-8x7B-Instruct-v0.1', 'mistralai/Mixtral-8x7B-Instruct-v0.1'), + Model('AI-ModelScope/Mixtral-8x7B-v0.1', 'mistralai/Mixtral-8x7B-v0.1'), + Model('AI-ModelScope/Mixtral-8x22B-v0.1', 'mistral-community/Mixtral-8x22B-v0.1'), + ], + requires=['transformers>=4.36']), + ModelGroup([ + Model('AI-ModelScope/Mixtral-8x7b-AQLM-2Bit-1x16-hf', 'ISTA-DASLab/Mixtral-8x7b-AQLM-2Bit-1x16-hf'), + ], + requires=['transformers>=4.38', 'aqlm', 'torch>=2.2.0']), + ], + TemplateType.llama, + get_model_tokenizer_with_flash_attn, + architectures=['MixtralForCausalLM'], + model_arch=ModelArch.llama)) + +register_model( + ModelMeta( + LLMModelType.mistral_nemo, [ + ModelGroup([ + Model('AI-ModelScope/Mistral-Small-Instruct-2409', 'mistralai/Mistral-Small-Instruct-2409'), + Model('LLM-Research/Mistral-Large-Instruct-2407', 'mistralai/Mistral-Large-Instruct-2407'), + Model('AI-ModelScope/Mistral-Nemo-Base-2407', 'mistralai/Mistral-Nemo-Base-2407'), + Model('AI-ModelScope/Mistral-Nemo-Instruct-2407', 'mistralai/Mistral-Nemo-Instruct-2407'), + ], + requires=['transformers>=4.43']), + ModelGroup([ + Model('AI-ModelScope/Ministral-8B-Instruct-2410', 'mistralai/Ministral-8B-Instruct-2410'), + ], + requires=['transformers>=4.46']), + ], + TemplateType.mistral_nemo, + get_model_tokenizer_with_flash_attn, + architectures=['MistralForCausalLM'], + model_arch=ModelArch.llama)) + +register_model( + ModelMeta( + LLMModelType.mistral_2501, [ + ModelGroup([ + Model('mistralai/Mistral-Small-24B-Base-2501', 'mistralai/Mistral-Small-24B-Base-2501'), + Model('mistralai/Mistral-Small-24B-Instruct-2501', 'mistralai/Mistral-Small-24B-Instruct-2501'), + ]), + ], + TemplateType.mistral_2501, + get_model_tokenizer_with_flash_attn, + architectures=['MistralForCausalLM'], + model_arch=ModelArch.llama)) + +register_model( + ModelMeta( + LLMModelType.zephyr, + [ + ModelGroup([ + Model('modelscope/zephyr-7b-beta', 'HuggingFaceH4/zephyr-7b-beta'), + ]), + ], + TemplateType.zephyr, + get_model_tokenizer_with_flash_attn, + model_arch=ModelArch.llama, + architectures=['MistralForCausalLM'], + requires=['transformers>=4.34'], + )) + +register_model( + ModelMeta( + LLMModelType.wizardlm2_moe, + [ModelGroup([ + Model('AI-ModelScope/WizardLM-2-8x22B', 'alpindale/WizardLM-2-8x22B'), + ])], + TemplateType.wizardlm2_moe, + get_model_tokenizer_with_flash_attn, + architectures=['MixtralForCausalLM'], + requires=['transformers>=4.36'], + )) + +register_model( + ModelMeta( + LLMModelType.wizardlm2, + [ModelGroup([ + Model('AI-ModelScope/WizardLM-2-7B-AWQ', 'MaziyarPanahi/WizardLM-2-7B-AWQ'), + ])], + TemplateType.wizardlm2, + get_model_tokenizer_with_flash_attn, + architectures=['MistralForCausalLM'], + requires=['transformers>=4.34'], + )) + + +def get_model_tokenizer_mistral_2503(model_dir: str, + model_info: ModelInfo, + model_kwargs: Dict[str, Any], + load_model: bool = True, + **kwargs): + try: + from transformers import Mistral3ForConditionalGeneration + except ImportError: + raise ImportError('Please install Gemma3ForConditionalGeneration by running ' + '`pip install git+https://github.com/huggingface/transformers@v4.49.0-Mistral-3`') + + kwargs['automodel_class'] = kwargs['automodel_class'] or Mistral3ForConditionalGeneration + model, processor = get_model_tokenizer_multimodal(model_dir, model_info, model_kwargs, load_model, **kwargs) + + return model, processor + + +register_model( + ModelMeta( + MLLMModelType.mistral_2503, + [ + ModelGroup([ + Model('mistralai/Mistral-Small-3.1-24B-Base-2503', 'mistralai/Mistral-Small-3.1-24B-Base-2503'), + Model('mistralai/Mistral-Small-3.1-24B-Instruct-2503', 'mistralai/Mistral-Small-3.1-24B-Instruct-2503'), + ]), + ], + TemplateType.mistral_2503, + get_model_tokenizer_mistral_2503, + architectures=['Mistral3ForConditionalGeneration'], + model_arch=ModelArch.llava_hf, + requires=['transformers>=4.49'], + ), ) diff --git a/ms-swift/swift/llm/model/model/mplug.py b/ms-swift/swift/llm/model/model/mplug.py new file mode 100644 index 0000000000000000000000000000000000000000..ff282127f607f8fa26776d0320708cebfb666af1 --- /dev/null +++ b/ms-swift/swift/llm/model/model/mplug.py @@ -0,0 +1,142 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import sys +from functools import partial +from typing import Any, Dict + +from transformers import AutoConfig +from transformers.dynamic_module_utils import get_class_from_dynamic_module + +from swift.llm import TemplateType +from swift.utils import get_logger +from ..constant import MLLMModelType +from ..model_arch import ModelArch +from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model +from ..utils import ModelInfo, git_clone_github, use_submodel_func +from .qwen import get_model_tokenizer_qwen + +logger = get_logger() + + +def get_model_tokenizer_mplug_owl2(model_dir: str, + model_info: ModelInfo, + model_kwargs: Dict[str, Any], + load_model: bool = True, + **kwargs): + local_repo_path = kwargs.get('local_repo_path') + if not local_repo_path: + local_repo_path = git_clone_github('https://github.com/X-PLUG/mPLUG-Owl') + local_repo_path = os.path.join(local_repo_path, 'mPLUG-Owl2') + sys.path.append(local_repo_path) + + # register + # https://github.com/X-PLUG/mPLUG-Owl/blob/main/mPLUG-Owl2/mplug_owl2/model/modeling_mplug_owl2.py#L447 + from mplug_owl2 import MPLUGOwl2LlamaForCausalLM + from transformers.models.clip.image_processing_clip import CLIPImageProcessor + model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True) + vocab_size = kwargs.pop('vocab_size', None) + if vocab_size is not None: + model_config.vocab_size = vocab_size + get_model_tokenizer_function = kwargs.pop('get_model_tokenizer_function', get_model_tokenizer_with_flash_attn) + model, tokenizer = get_model_tokenizer_function( + model_dir, model_info, model_kwargs, load_model, model_config=model_config, **kwargs) + logger.info('Please ignore the unimported warning.') + processor = CLIPImageProcessor.from_pretrained(model_dir) + processor.tokenizer = tokenizer + return model, processor + + +register_model( + ModelMeta( + MLLMModelType.mplug_owl2, [ModelGroup([ + Model('iic/mPLUG-Owl2', 'MAGAer13/mplug-owl2-llama2-7b'), + ])], + TemplateType.mplug_owl2, + get_model_tokenizer_mplug_owl2, + model_arch=ModelArch.mplug_owl2, + requires=['transformers<4.35', 'icecream'], + tags=['vision']), ) + +register_model( + ModelMeta( + MLLMModelType.mplug_owl2_1, [ModelGroup([ + Model('iic/mPLUG-Owl2.1', 'Mizukiluke/mplug_owl_2_1'), + ])], + TemplateType.mplug_owl2, + partial( + get_model_tokenizer_mplug_owl2, vocab_size=151851, get_model_tokenizer_function=get_model_tokenizer_qwen), + model_arch=ModelArch.mplug_owl2_1, + requires=['transformers<4.35', 'icecream'], + tags=['vision'])) + + +def get_model_tokenizer_mplug_owl3(model_dir: str, + model_info: ModelInfo, + model_kwargs: Dict[str, Any], + load_model: bool = True, + **kwargs): + get_class_from_dynamic_module('configuration_hyper_qwen2.HyperQwen2Config', model_dir) + model_cls = get_class_from_dynamic_module('modeling_mplugowl3.mPLUGOwl3Model', model_dir) + model_cls._no_split_modules = ['SiglipEncoderLayer'] + model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs) + processor = model.init_processor(tokenizer) + if model is not None: + func_list = ['generate', 'forward'] + use_submodel_func(model, 'language_model', func_list) + return model, processor + + +register_model( + ModelMeta( + MLLMModelType.mplug_owl3, [ + ModelGroup([ + Model('iic/mPLUG-Owl3-1B-241014', 'mPLUG/mPLUG-Owl3-1B-241014'), + Model('iic/mPLUG-Owl3-2B-241014', 'mPLUG/mPLUG-Owl3-2B-241014'), + Model('iic/mPLUG-Owl3-7B-240728', 'mPLUG/mPLUG-Owl3-7B-240728'), + ]), + ], + TemplateType.mplug_owl3, + get_model_tokenizer_mplug_owl3, + architectures=['mPLUGOwl3Model'], + model_arch=ModelArch.mplug_owl3, + requires=['transformers>=4.36', 'icecream', 'decord'], + tags=['vision', 'video'])) + +register_model( + ModelMeta( + MLLMModelType.mplug_owl3_241101, [ + ModelGroup([ + Model('iic/mPLUG-Owl3-7B-241101', 'mPLUG/mPLUG-Owl3-7B-241101'), + ]), + ], + TemplateType.mplug_owl3_241101, + get_model_tokenizer_mplug_owl3, + architectures=['mPLUGOwl3Model'], + model_arch=ModelArch.mplug_owl3, + requires=['transformers>=4.36', 'icecream'], + tags=['vision', 'video'])) + + +def get_model_tokenizer_doc_owl2(model_dir: str, + model_info: ModelInfo, + model_kwargs: Dict[str, Any], + load_model: bool = True, + **kwargs): + model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs) + processor = model.init_processor(tokenizer, basic_image_size=504, crop_anchors='grid_12') + return model, processor + + +register_model( + ModelMeta( + MLLMModelType.doc_owl2, [ + ModelGroup([ + Model('iic/DocOwl2', 'mPLUG/DocOwl2'), + ]), + ], + TemplateType.doc_owl2, + get_model_tokenizer_doc_owl2, + architectures=['mPLUGDocOwl2'], + model_arch=ModelArch.doc_owl2, + requires=['transformers>=4.36', 'icecream'], + tags=['vision'])) diff --git a/ms-swift/swift/llm/model/utils.py b/ms-swift/swift/llm/model/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ccd48ad1e930f7eb8eb22fe30f392f3fab581268 --- /dev/null +++ b/ms-swift/swift/llm/model/utils.py @@ -0,0 +1,451 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from dataclasses import dataclass +from functools import wraps +from types import MethodType +from typing import Any, Dict, List, Literal, Optional, Tuple, TypeVar, Union + +import torch +from accelerate.utils import find_device +from modelscope.hub.utils.utils import get_cache_dir +from torch import nn +from transformers import PretrainedConfig + +from swift.hub import get_hub +from swift.llm import to_device +from swift.utils import deep_getattr, get_logger, safe_ddp_context, subprocess_run + +logger = get_logger() + +_T = TypeVar('_T') + + +class AttnImpl: + flash_attn = 'flash_attn' + sdpa = 'sdpa' + eager = 'eager' + + attn_impl_keys = ['_attn_implementation', 'attn_implementation', 'llm_attn_implementation'] + use_flash_attn_keys = ['_flash_attn_2_enabled', 'use_flash_attn', '_use_flash_attention_2'] + + @staticmethod + def to_use_flash_attn(attn_impl: Optional[str], auto_value: _T = None) -> Union[bool, _T]: + if attn_impl is None: + return auto_value + return attn_impl == AttnImpl.flash_attn + + @staticmethod + def update_attn_impl(config: PretrainedConfig, + attn_impl: Optional[str], + attn_impl_keys: Optional[List[str]] = None) -> None: + if attn_impl is None: + return + logger.info(f'attn_impl: {attn_impl}') + use_flash_attn = AttnImpl.to_use_flash_attn(attn_impl) + if use_flash_attn: + attn_impl = 'flash_attention_2' + if isinstance(attn_impl_keys, str): + attn_impl_keys = [attn_impl_keys] + attn_impl_keys = attn_impl_keys or AttnImpl.attn_impl_keys + for key in attn_impl_keys: + HfConfigFactory.set_config_attr(config, key, attn_impl, ensure_set=False) + for key in AttnImpl.use_flash_attn_keys: + HfConfigFactory.set_config_attr(config, key, use_flash_attn, ensure_set=False) + + +@dataclass +class ModelInfo: + model_type: str + model_dir: str + torch_dtype: torch.dtype + max_model_len: int + quant_method: Literal['gptq', 'awq', 'bnb', 'aqlm', 'hqq', None] + quant_bits: int + + # extra + rope_scaling: Optional[Dict[str, Any]] = None + config: Optional[PretrainedConfig] = None + task_type: Literal['causal_lm', 'seq_cls', 'embedding', None] = None + num_labels: Optional[int] = None + + def __post_init__(self): + from .register import get_model_name + self.model_name = get_model_name(self.model_dir) + + +class HfConfigFactory: + """This class is used to read config from config.json(maybe params.json also)""" + + @staticmethod + def get_torch_dtype(config: Union[PretrainedConfig, Dict[str, Any]], + quant_info: Dict[str, Any]) -> Optional[torch.dtype]: + for key in ['torch_dtype', 'params_dtype']: + torch_dtype = HfConfigFactory.get_config_attr(config, key) + if torch_dtype is not None: + break + torch_dtype = HfConfigFactory.to_torch_dtype(torch_dtype) + if torch_dtype is None: + torch_dtype = quant_info.get('torch_dtype') + return torch_dtype + + @staticmethod + def _get_config_attrs(config: Union[PretrainedConfig, Dict[str, Any]], + attr_name: str, + parent_key: Optional[str] = None) -> List[Tuple[PretrainedConfig, Any]]: + res = [] + if isinstance(config, dict): + keys = config.keys() + elif isinstance(config, PretrainedConfig): + keys = dir(config) + else: + return [] + + value = deep_getattr(config, attr_name, None) + if value is not None and parent_key in [None, 'language_config', 'llm_config', 'text_config']: + res.append((config, value)) + + for k in keys: + if k.endswith('_config'): + if isinstance(config, dict): + v = config[k] + else: + v = getattr(config, k) + res += HfConfigFactory._get_config_attrs(v, attr_name, k) + return res + + @staticmethod + def get_config_attr(config: Union[PretrainedConfig, Dict[str, Any]], attr_name: str) -> Optional[Any]: + """Get the value of the attribute named attr_name.""" + attrs = HfConfigFactory._get_config_attrs(config, attr_name) + if len(attrs) == 0: + return None + else: + return attrs[0][1] + + @staticmethod + def set_config_attr(config: Union[PretrainedConfig, Dict[str, Any]], + attr_name: str, + value: Any, + ensure_set: bool = True) -> int: + """Set all the attr_name attributes to value.""" + attrs = HfConfigFactory._get_config_attrs(config, attr_name) + if ensure_set and len(attrs) == 0: + attrs.append((config, None)) + for config, _ in attrs: + if isinstance(config, dict): + config[attr_name] = value + else: + setattr(config, attr_name, value) + return len(attrs) + + @staticmethod + def set_model_config_attr(model, attr_name: str, value: Any) -> None: + for module in model.modules(): + if getattr(module, 'config', None) and getattr(module.config, attr_name, value) != value: + setattr(module.config, attr_name, value) + + @staticmethod + def get_max_model_len(config: Union[PretrainedConfig, Dict[str, Any]]) -> Optional[int]: + """Get the max length supported by the model""" + INF = int(1e9) + max_model_len = INF + + possible_keys = [ + 'seq_length', # qwen, chatglm + 'max_position_embeddings', # qwen1.5, llama2 + 'n_positions', # polylm, phi-2 + 'model_max_length', # baichuan2 + # others + 'seq_len', + 'max_seq_len', + 'max_sequence_length', + 'max_seq_length', + ] + for key in possible_keys: + max_len_key = HfConfigFactory.get_config_attr(config, key) + if max_len_key is not None: + max_model_len = min(max_model_len, max_len_key) + if max_model_len == INF: + max_model_len = None + return max_model_len + + @staticmethod + def compat_zero3(config: PretrainedConfig) -> None: + value = HfConfigFactory.get_config_attr(config, 'hidden_size') + try: + # AttributeError: can't set attribute 'hidden_size' + config.hidden_size = value + except AttributeError: + pass + + @staticmethod + def to_torch_dtype(torch_dtype: Union[str, torch.dtype, None]) -> Optional[torch.dtype]: + if torch_dtype is None: + return None + if isinstance(torch_dtype, str): + torch_dtype = eval(f'torch.{torch_dtype}') + return torch_dtype + + @staticmethod + def get_quant_info(config: Union[PretrainedConfig, Dict[str, Any]]) -> Optional[Dict[str, Any]]: + """Get quant_method, quant_bits, dtype. not support hqq/eetq now, support awq/gptq/bnb/aqlm""" + if isinstance(config, dict): + quantization_config = config.get('quantization_config') + else: + quantization_config = getattr(config, 'quantization_config', None) + if quantization_config is None: + return + quantization_config = dict(quantization_config) + quant_method = quantization_config.get('quant_method') + res = {} + if quant_method in {'gptq', 'awq', 'aqlm'}: + res['quant_method'] = quant_method + res['torch_dtype'] = torch.float16 + quant_bits = quantization_config.get('bits') + if quant_bits is not None: + res['quant_bits'] = quant_bits + elif quant_method == 'bitsandbytes': + res['quant_method'] = 'bnb' + load_in_4bit = quantization_config.get('_load_in_4bit') + load_in_8bit = quantization_config.get('_load_in_8bit') + bnb_4bit_compute_dtype = quantization_config.get('bnb_4bit_compute_dtype') + if load_in_4bit: + res['quant_bits'] = 4 + elif load_in_8bit: + res['quant_bits'] = 8 + res['torch_dtype'] = HfConfigFactory.to_torch_dtype(bnb_4bit_compute_dtype) + elif quant_method == 'hqq': + res['quant_method'] = quant_method + res['quant_bits'] = quantization_config['quant_config']['weight_quant_params']['nbits'] + + return res or None + + +def safe_snapshot_download(model_id_or_path: str, + revision: Optional[str] = None, + download_model: bool = True, + use_hf: Optional[bool] = None, + hub_token: Optional[str] = None, + ignore_patterns: Optional[List[str]] = None, + check_local: bool = False, + **kwargs) -> str: + """Download model protected by DDP context + + Args: + model_id_or_path: The model id or model path + revision: The model revision + download_model: Download model bin/safetensors files or not + use_hf: use huggingface or modelscope + + Returns: + model_dir + """ + if check_local: + model_suffix = model_id_or_path.rsplit('/', 1)[-1] + if os.path.exists(model_suffix): + model_dir = os.path.abspath(os.path.expanduser(model_suffix)) + logger.info(f'Loading the model using local model_dir: {model_dir}') + return model_dir + if ignore_patterns is None: + ignore_patterns = [ + '*.zip', '*.gguf', '*.pth', '*.pt', 'consolidated*', 'onnx/*', '*.safetensors.md', '*.msgpack', '*.onnx', + '*.ot', '*.h5' + ] + if not download_model: + ignore_patterns += ['*.bin', '*.safetensors'] + hub = get_hub(use_hf) + if model_id_or_path.startswith('~'): + model_id_or_path = os.path.abspath(os.path.expanduser(model_id_or_path)) + with safe_ddp_context(hash_id=model_id_or_path): + model_path_to_check = '/'.join(model_id_or_path.split(':', 1)) + if os.path.exists(model_id_or_path): + model_dir = model_id_or_path + sub_folder = None + elif os.path.exists(model_path_to_check): + model_dir = model_path_to_check + sub_folder = None + else: + if model_id_or_path.startswith('/'): # startswith + raise ValueError(f"path: '{model_id_or_path}' not found") + model_id_or_path = model_id_or_path.split(':', 1) # get sub_folder + if len(model_id_or_path) == 1: + model_id_or_path = [model_id_or_path[0], None] + model_id_or_path, sub_folder = model_id_or_path + if sub_folder is not None: + kwargs['allow_patterns'] = [f"{sub_folder.rstrip('/')}/*"] + model_dir = hub.download_model(model_id_or_path, revision, ignore_patterns, token=hub_token, **kwargs) + + logger.info(f'Loading the model using model_dir: {model_dir}') + + model_dir = os.path.abspath(os.path.expanduser(model_dir)) + if sub_folder: + model_dir = os.path.join(model_dir, sub_folder) + assert os.path.isdir(model_dir), f'model_dir: {model_dir}' + return model_dir + + +def git_clone_github(github_url: str, + local_repo_name: Optional[str] = None, + branch: Optional[str] = None, + commit_hash: Optional[str] = None) -> str: + if github_url.endswith('.git'): + github_url = github_url[:-4] + git_cache_dir = os.path.join(get_cache_dir(), '_github') + os.makedirs(git_cache_dir, exist_ok=True) + if local_repo_name is None: + github_url = github_url.rstrip('/') + local_repo_name = github_url.rsplit('/', 1)[1] + local_repo_path = os.path.join(git_cache_dir, local_repo_name) + with safe_ddp_context(hash_id=local_repo_path): + if not os.path.exists(local_repo_path): + github_url = f'{github_url}.git' + command = ['git', '-C', git_cache_dir, 'clone', github_url, local_repo_name] + command_str = f"git -C '{git_cache_dir}' clone '{github_url}' {local_repo_name}" + if branch is not None: + command += ['--branch', branch] + command_str += f' --branch {branch}' + logger.info(f'Run the command: `{command_str}`') + subprocess_run(command) + + if commit_hash is not None: + git_cache_path = os.path.join(git_cache_dir, local_repo_name) + command = ['git', '-C', git_cache_path, 'reset', '--hard', commit_hash] + command_str = f"git -C '{git_cache_path}' reset '--hard' {commit_hash}" + logger.info(f'Run the command: `{command_str}`') + subprocess_run(command) + + logger.info(f'local_repo_path: {local_repo_path}') + return local_repo_path + + +def use_submodel_func(model, submodel_name: str, func_list: Optional[List[str]] = None) -> None: + if func_list is None: + func_list = ['generate', 'get_input_embeddings', 'gradient_checkpointing_enable', 'forward'] + submodel = getattr(model, submodel_name) + + def _get_new_func(func_name: str): + _old_func = getattr(submodel, func_name).__func__ + + @wraps(_old_func) + def _new_func(self, *args, **kwargs): + res = _old_func(submodel, *args, **kwargs) + if func_name == 'forward': + device = find_device(args) + if device is None: + device = find_device(kwargs) + if hasattr(res, 'logits'): + res.logits = to_device(res.logits, device) + if hasattr(res, 'loss'): + res.loss = to_device(res.loss, device) + if isinstance(res, dict) and 'last_hidden_state' in res: + res['last_hidden_state'] = to_device(res['last_hidden_state'], device) + return res + + return _new_func + + for key in func_list: + setattr(model, key, MethodType(_get_new_func(key), model)) + if key == 'generate' and model.device != submodel.device: + submodel.__class__.device = model.device + if key == 'forward' and 'generate' in func_list: + setattr(submodel, key, MethodType(_get_new_func(key), submodel)) # fix device_map + + +class InitModelStrategy: + + @staticmethod + def is_uninitialized(param: torch.Tensor) -> bool: + """ + Check if a parameter is uninitialized or has numerically unstable values. + Criteria: + - Tensor has NaN or Inf values + - Tensor stats (mean or std) are outside reasonable range + """ + if param.numel() == 0: + return False + + with torch.no_grad(): + mean_abs = param.abs().mean() + std = param.std() + + # NaN or Inf + if not torch.isfinite(mean_abs) or not torch.isfinite(std): + return True + + # Use empirically safe threshold + MAX_THRESHOLD = 1e7 + if mean_abs > MAX_THRESHOLD or std > MAX_THRESHOLD: + return True + + return False + + @staticmethod + def constant_init(param: torch.Tensor, c: float = 0) -> None: + nn.init.constant_(param, c) + + @staticmethod + def uniform_init(param: torch.Tensor, a: float = -0.1, b: float = 0.1) -> None: + nn.init.uniform_(param, a, b) + + @staticmethod + def normal_init(param: torch.Tensor, mean: float = 0.0, std: float = 0.01) -> None: + nn.init.normal_(param, mean, std) + + @staticmethod + def _init_high_dim(param: torch.Tensor, init_func, *args, **kwargs) -> None: + """Helper for high-dimensional initialization methods.""" + if param.dim() > 1: + init_func(param, *args, **kwargs) + elif param.dim() == 1 and param.size(0) > 0: + InitModelStrategy.constant_init(param) + + @staticmethod + def xavier_uniform_init(param: torch.Tensor) -> None: + InitModelStrategy._init_high_dim(param, nn.init.xavier_uniform_) + + @staticmethod + def xavier_normal_init(param: torch.Tensor) -> None: + InitModelStrategy._init_high_dim(param, nn.init.xavier_normal_) + + @staticmethod + def kaiming_uniform_init(param: torch.Tensor) -> None: + InitModelStrategy._init_high_dim( + param, nn.init.kaiming_uniform_, mode='fan_out', nonlinearity='leaky_relu', a=0.1) + + @staticmethod + def kaiming_normal_init(param: torch.Tensor) -> None: + InitModelStrategy._init_high_dim(param, nn.init.kaiming_normal_, mode='fan_in', nonlinearity='relu') + + @staticmethod + def orthogonal_init(param: torch.Tensor) -> None: + nn.init.orthogonal_(param, gain=1.0) + + _INIT_STRATEGY_MAP = { + 'zero': constant_init, + 'uniform': uniform_init, + 'normal': normal_init, + 'xavier_uniform': xavier_uniform_init, + 'xavier_normal': xavier_normal_init, + 'kaiming_uniform': kaiming_uniform_init, + 'kaiming_normal': kaiming_normal_init, + 'orthogona': orthogonal_init, + } + + @staticmethod + def init_parameters(model: nn.Module, init_strategy: str) -> None: + """Initialize model parameters using the specified strategy. + Args: + model: The model whose parameters to initialize + init_strategy: Name of initialization strategy + """ + if init_strategy not in InitModelStrategy._INIT_STRATEGY_MAP: + raise ValueError(f'Unknown initialization strategy: {init_strategy}') + + logger.info(f'initialization strategy: {init_strategy}') + + init_func = InitModelStrategy._INIT_STRATEGY_MAP[init_strategy] + + for name, param in model.named_parameters(): + if InitModelStrategy.is_uninitialized(param): + logger.info(f'Initializing parameters: {name}.') + init_func(param)