yongqiang committed
Commit e088b82 · 0 Parent(s)

init this repo

Files changed (49)
  1. .gitattributes +40 -0
  2. .gitignore +4 -0
  3. README.md +97 -0
  4. config.json +0 -0
  5. infer_axmodel.py +100 -0
  6. qwen3_embedding_0.6b_axmodel/model.embed_tokens.weight.bfloat16.bin +3 -0
  7. qwen3_embedding_0.6b_axmodel/model.embed_tokens.weight.float32.bin +3 -0
  8. qwen3_embedding_0.6b_axmodel/model.embed_tokens.weight.npy +3 -0
  9. qwen3_embedding_0.6b_axmodel/qwen3_p128_l0_together.axmodel +3 -0
  10. qwen3_embedding_0.6b_axmodel/qwen3_p128_l10_together.axmodel +3 -0
  11. qwen3_embedding_0.6b_axmodel/qwen3_p128_l11_together.axmodel +3 -0
  12. qwen3_embedding_0.6b_axmodel/qwen3_p128_l12_together.axmodel +3 -0
  13. qwen3_embedding_0.6b_axmodel/qwen3_p128_l13_together.axmodel +3 -0
  14. qwen3_embedding_0.6b_axmodel/qwen3_p128_l14_together.axmodel +3 -0
  15. qwen3_embedding_0.6b_axmodel/qwen3_p128_l15_together.axmodel +3 -0
  16. qwen3_embedding_0.6b_axmodel/qwen3_p128_l16_together.axmodel +3 -0
  17. qwen3_embedding_0.6b_axmodel/qwen3_p128_l17_together.axmodel +3 -0
  18. qwen3_embedding_0.6b_axmodel/qwen3_p128_l18_together.axmodel +3 -0
  19. qwen3_embedding_0.6b_axmodel/qwen3_p128_l19_together.axmodel +3 -0
  20. qwen3_embedding_0.6b_axmodel/qwen3_p128_l1_together.axmodel +3 -0
  21. qwen3_embedding_0.6b_axmodel/qwen3_p128_l20_together.axmodel +3 -0
  22. qwen3_embedding_0.6b_axmodel/qwen3_p128_l21_together.axmodel +3 -0
  23. qwen3_embedding_0.6b_axmodel/qwen3_p128_l22_together.axmodel +3 -0
  24. qwen3_embedding_0.6b_axmodel/qwen3_p128_l23_together.axmodel +3 -0
  25. qwen3_embedding_0.6b_axmodel/qwen3_p128_l24_together.axmodel +3 -0
  26. qwen3_embedding_0.6b_axmodel/qwen3_p128_l25_together.axmodel +3 -0
  27. qwen3_embedding_0.6b_axmodel/qwen3_p128_l26_together.axmodel +3 -0
  28. qwen3_embedding_0.6b_axmodel/qwen3_p128_l27_together.axmodel +3 -0
  29. qwen3_embedding_0.6b_axmodel/qwen3_p128_l2_together.axmodel +3 -0
  30. qwen3_embedding_0.6b_axmodel/qwen3_p128_l3_together.axmodel +3 -0
  31. qwen3_embedding_0.6b_axmodel/qwen3_p128_l4_together.axmodel +3 -0
  32. qwen3_embedding_0.6b_axmodel/qwen3_p128_l5_together.axmodel +3 -0
  33. qwen3_embedding_0.6b_axmodel/qwen3_p128_l6_together.axmodel +3 -0
  34. qwen3_embedding_0.6b_axmodel/qwen3_p128_l7_together.axmodel +3 -0
  35. qwen3_embedding_0.6b_axmodel/qwen3_p128_l8_together.axmodel +3 -0
  36. qwen3_embedding_0.6b_axmodel/qwen3_p128_l9_together.axmodel +3 -0
  37. qwen3_embedding_0.6b_axmodel/qwen3_post.axmodel +3 -0
  38. qwen3_embedding_0.6b_tokenizer/.gitattributes +36 -0
  39. qwen3_embedding_0.6b_tokenizer/1_Pooling/config.json +10 -0
  40. qwen3_embedding_0.6b_tokenizer/README.md +292 -0
  41. qwen3_embedding_0.6b_tokenizer/config.json +30 -0
  42. qwen3_embedding_0.6b_tokenizer/config_sentence_transformers.json +8 -0
  43. qwen3_embedding_0.6b_tokenizer/generation_config.json +6 -0
  44. qwen3_embedding_0.6b_tokenizer/merges.txt +0 -0
  45. qwen3_embedding_0.6b_tokenizer/modules.json +20 -0
  46. qwen3_embedding_0.6b_tokenizer/tokenizer.json +3 -0
  47. qwen3_embedding_0.6b_tokenizer/tokenizer_config.json +240 -0
  48. qwen3_embedding_0.6b_tokenizer/vocab.json +0 -0
  49. utils/infer_func.py +217 -0
.gitattributes ADDED
@@ -0,0 +1,40 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.axmodel filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
+ examples/red-panda.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,4 @@
+ __pycache__
+ build-output/
+ tmp/
+ *.safetensors
README.md ADDED
@@ -0,0 +1,97 @@
+ ---
+ library_name: transformers
+ license: mit
+ base_model:
+ - Qwen/Qwen3-Embedding-0.6B
+ tags:
+ - transformers
+ - sentence-transformers
+ - sentence-similarity
+ - feature-extraction
+ - text-embeddings-inference
+ pipeline_tag: feature-extraction
+ ---
+
+ # Qwen3-Embedding-0.6B
+
+ This version of Qwen3-Embedding-0.6B has been converted to run on the Axera NPU using **w8a16** quantization.
+
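+ Here **w8a16** means the weights are quantized to 8-bit integers while activations stay in 16-bit floating point. A minimal sketch of the idea (illustrative only; the actual Pulsar2 quantizer is more involved, and the shapes below simply follow this model's `hidden_size=1024` / `intermediate_size=3072`):
+
+ ```python
+ import numpy as np
+ from ml_dtypes import bfloat16
+
+ rng = np.random.default_rng(0)
+ w = rng.standard_normal((1024, 3072)).astype(np.float32)       # fp32 projection weights
+ scale = np.abs(w).max(axis=0) / 127.0                          # per-output-channel scale
+ w_q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)  # 8-bit weights (the "w8")
+
+ x = rng.standard_normal((1, 1024)).astype(bfloat16)            # 16-bit activations (the "a16")
+ y = (x.astype(np.float32) @ (w_q.astype(np.float32) * scale)).astype(bfloat16)
+ ```
+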
+ Compatible with Pulsar2 version: 4.1
+
+ ## Conversion tool links
+
+ If you are interested in model conversion, you can export the axmodel yourself starting from the original repo:
+ https://huggingface.co/Qwen/Qwen3-Embedding-0.6B
+
+ [Pulsar2 Link, How to Convert LLM from Huggingface to axmodel](https://pulsar2-docs.readthedocs.io/en/latest/appendix/build_llm.html)
+
+ ## Support Platform
+
+ - AX650
+ - AX650N DEMO Board
+ - [M4N-Dock(爱芯派Pro)](https://wiki.sipeed.com/hardware/zh/maixIV/m4ndock/m4ndock.html)
+ - [M.2 Accelerator card](https://axcl-docs.readthedocs.io/zh-cn/latest/doc_guide_hardware.html)
+
+ ### Time consumed by each subgraph
+ ```sh
+ g1: 5.561 ms
+ g2: 9.140 ms
+ g3: 12.757 ms
+ g4: 16.446 ms
+ g5: 21.392 ms
+ g6: 23.712 ms
+ g7: 27.174 ms
+ g8: 30.897 ms
+ g9: 34.829 ms
+ ```
+
+ | Chips | ttft | w8a16 |
+ |--|--|--|
+ | AX650 | 155.708 ms | 6.42 tokens/sec |
+ | AX650 | 5093.42 ms | 0.19 tokens/sec |
+
+ Longest time consumption: 181.908 ms
+ Shortest time consumption: 5.561 ms
+ LayerNum: 28
+
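+ The "longest" figure is just the sum of the per-subgraph latencies listed above, and the "shortest" is the fastest single subgraph, which is easy to verify:
+
+ ```python
+ # Per-subgraph prefill latencies (ms) reported above.
+ subgraph_ms = [5.561, 9.140, 12.757, 16.446, 21.392, 23.712, 27.174, 30.897, 34.829]
+
+ print(f"longest:  {sum(subgraph_ms):.3f} ms")   # 181.908 ms
+ print(f"shortest: {min(subgraph_ms):.3f} ms")   # 5.561 ms
+ ```
+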
+ ## How to use
+
+ Download all files from this repository to the device.
+
+ **If you are using an AX650 board**
+ ```
+ root@ax650 ~/yongqiang/push_hugging_face/Qwen3-Embedding-0.6B # tree -L 1
+ .
+ ├── config.json
+ ├── infer_axmodel.py
+ ├── qwen3_embedding_0.6b_axmodel
+ ├── qwen3_embedding_0.6b_tokenizer
+ ├── README.md
+ └── utils
+
+ 3 directories, 3 files
+ ```
+
+ #### Install transformers
+
+ ```
+ # Requires transformers>=4.51.0
+ pip install transformers==4.51.0
+ ```
+
+ #### Inference on an AX650 host, such as the M4N-Dock(爱芯派Pro) or the AX650N DEMO Board
+
+ ```
+ $ python3 infer_axmodel.py
+ Model loaded successfully!
+ slice_indices: [0]
+ Slice prefill done: 0
+ slice_indices: [0]
+ Slice prefill done: 0
+ slice_indices: [0]
+ Slice prefill done: 0
+ slice_indices: [0]
+ Slice prefill done: 0
+ [[0.7555467486381531, 0.1756950318813324], [0.4137178063392639, 0.4459586441516876]]
+ ```
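+
+ The printed matrix holds the query-document cosine similarities. The on-device w8a16 scores deviate somewhat from the fp32 Torch reference (both matrices are kept as comments at the bottom of `infer_axmodel.py`); a quick way to quantify the drift:
+
+ ```python
+ import numpy as np
+
+ torch_ref = np.array([[0.7646, 0.1414], [0.1355, 0.6000]])  # fp32 Torch reference
+ ax_scores = np.array([[0.7555, 0.1757], [0.4137, 0.4460]])  # w8a16 on AX650
+ print(np.abs(torch_ref - ax_scores).max())                  # worst-case absolute deviation
+ ```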
config.json ADDED
File without changes
infer_axmodel.py ADDED
@@ -0,0 +1,100 @@
+ # Requires transformers>=4.51.0
+ import os
+ import numpy as np
+ from ml_dtypes import bfloat16
+ from utils.infer_func import InferManager
+ import argparse
+
+ import torch
+ import torch.nn.functional as F
+ from torch import Tensor
+ from transformers import AutoTokenizer, AutoConfig
+
+
+ def last_token_pool(last_hidden_states: Tensor,
+                     attention_mask: Tensor) -> Tensor:
+     # Pool the hidden state of the last non-padding token of each sequence
+     # (handles both left- and right-padded batches).
+     left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
+     if left_padding:
+         return last_hidden_states[:, -1]
+     else:
+         sequence_lengths = attention_mask.sum(dim=1) - 1
+         batch_size = last_hidden_states.shape[0]
+         return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
+
+
+ def get_detailed_instruct(task_description: str, query: str) -> str:
+     return f'Instruct: {task_description}\nQuery:{query}'
+
+
+ if __name__ == "__main__":
+
+     """
+     python3 infer_axmodel.py
+     """
+
+     parser = argparse.ArgumentParser(description="Model configuration parameters")
+     parser.add_argument("--hf_model", type=str, default="./qwen3_embedding_0.6b_tokenizer/",
+                         help="Path to the HuggingFace tokenizer/config directory")
+     parser.add_argument("--axmodel_path", type=str, default="./qwen3_embedding_0.6b_axmodel/",
+                         help="Path to the compiled axmodel files")
+     args = parser.parse_args()
+
+     hf_model_path = args.hf_model
+     axmodel_path = args.axmodel_path
+
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     # The embedding matrix is kept outside the axmodel; token embeddings are gathered on the host.
+     embeds = np.load(os.path.join(axmodel_path, "model.embed_tokens.weight.npy"))
+
+     # Each query must come with a one-sentence instruction that describes the task
+     task = 'Given a web search query, retrieve relevant passages that answer the query'
+
+     queries = [
+         get_detailed_instruct(task, 'What is the capital of China?'),
+         get_detailed_instruct(task, 'Explain gravity')
+     ]
+     # No need to add instruction for retrieval documents
+     documents = [
+         "The capital of China is Beijing.",
+         "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun."
+     ]
+     input_texts = queries + documents
+
+     tokenizer = AutoTokenizer.from_pretrained(hf_model_path, padding_side='left')
+     cfg = AutoConfig.from_pretrained(hf_model_path)
+
+     max_length = 8192
+
+     # Tokenize the input texts
+     batch_dict = tokenizer(
+         input_texts,
+         padding=True,
+         truncation=True,
+         max_length=max_length,
+         return_tensors="pt",
+     )
+     batch_dict.to(device)
+
+     input_ids = batch_dict['input_ids']
+     attention_mask = batch_dict['attention_mask']
+     inputs_embeds = np.take(embeds, input_ids.cpu().numpy(), axis=0)
+     prefill_data = inputs_embeds.astype(bfloat16)
+     token_ids = input_ids[0].cpu().numpy().tolist()
+     token_len = len(token_ids)
+
+     batch_num, seq_len, seq_dim = inputs_embeds.shape
+
+     imer = InferManager(cfg, axmodel_path)
+     last_hidden_state = np.zeros((batch_num, seq_len, seq_dim), dtype=bfloat16)
+     for batch_idx in range(batch_num):
+         last_hidden_state[batch_idx] = imer.prefill(tokenizer, token_ids, prefill_data[batch_idx], slice_len=128)
+
+     embeddings = last_token_pool(torch.from_numpy(last_hidden_state.astype(np.float32)), batch_dict['attention_mask'])
+     # normalize embeddings
+     embeddings = F.normalize(embeddings, p=2, dim=1)
+     scores = (embeddings[:2] @ embeddings[2:].T)
+     print(scores.tolist())
+     # Torch: [[0.7645568251609802, 0.14142508804798126], [0.13549736142158508, 0.5999549627304077]]
+     # Axmod: [[0.755547046661377, 0.17569507658481598], [0.4137181341648102, 0.4459584951400757]]
qwen3_embedding_0.6b_axmodel/model.embed_tokens.weight.bfloat16.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a55b140d86852835bd18d8200222a9f302340730f0670eb7e23a4895e5489033
+ size 310618112
qwen3_embedding_0.6b_axmodel/model.embed_tokens.weight.float32.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c7a027c062fb61cd505e046bc832345be155e1eb2fab629675cebe7973646c85
+ size 621236224
qwen3_embedding_0.6b_axmodel/model.embed_tokens.weight.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6bbdc47aee1b4cdb97a42a255306d4e0a1cb52f797bfdc32f94469eb0cd0744e
+ size 621236352
qwen3_embedding_0.6b_axmodel/qwen3_p128_l0_together.axmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:025e17a92f3f19d58a36ef119294598073c4ccdc794aa9d4a2845a99b0c6b53d
+ size 28019747
qwen3_embedding_0.6b_axmodel/qwen3_p128_l10_together.axmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:95d809d6cc6889517b1aad7a4e62e51ffbf75580dda5ceafb667dbd5ac10ba6e
+ size 28019779
qwen3_embedding_0.6b_axmodel/qwen3_p128_l11_together.axmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7fb730805877eceea4aa037694bc2abb830fa960c6666d24b976d7ae35c058d0
+ size 28018723
qwen3_embedding_0.6b_axmodel/qwen3_p128_l12_together.axmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:eef7266dff6af522a0a63095067a1c7823a9a1213e7bd498bcdb97f2814523ba
+ size 28019427
qwen3_embedding_0.6b_axmodel/qwen3_p128_l13_together.axmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9322be28dad0729b006238364297659594e2193a97a37556fb06f63d3fec9fa0
+ size 28019459
qwen3_embedding_0.6b_axmodel/qwen3_p128_l14_together.axmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c5a8f239e4a4e793e0bdf86226c08c2089f5199118bd38d2be4957f9b7023dda
+ size 28018723
qwen3_embedding_0.6b_axmodel/qwen3_p128_l15_together.axmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8767d8ce2a3a20049f30da7982b3cb15799e6e0959b3deb484db15fef426010b
+ size 28019619
qwen3_embedding_0.6b_axmodel/qwen3_p128_l16_together.axmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a80f39502798ca50e070ace41527fb176232ce315d6175ea65ec8e65467112fa
+ size 28018723
qwen3_embedding_0.6b_axmodel/qwen3_p128_l17_together.axmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:108d4abc6ceef53bf17786e0512d6c4e2de8d531cc14b1192fb9923ceeb2c10f
+ size 28018723
qwen3_embedding_0.6b_axmodel/qwen3_p128_l18_together.axmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c8e07c47bd98d9062c3c5bc48fc7057f363cbb5b31c62540346814299b1843f4
+ size 28019299
qwen3_embedding_0.6b_axmodel/qwen3_p128_l19_together.axmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4a713464dcbd8187f27cecc4d6917acb8726ab9cffeaf6ed1e9c2bed6059cdd0
+ size 28019075
qwen3_embedding_0.6b_axmodel/qwen3_p128_l1_together.axmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c37fefd3be1f8699961f1b5d6c164a74960ce90a2b219514add7ce960534a292
+ size 28019363
qwen3_embedding_0.6b_axmodel/qwen3_p128_l20_together.axmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ad7cd29e2ea7644f6ff39a179efb6b646743fc47a3c4049877db258b8072fbd1
+ size 28018723
qwen3_embedding_0.6b_axmodel/qwen3_p128_l21_together.axmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ff616a72c372edc0068909aea76abd76b1c5121fa86e0052d93dfe42fdd247b7
+ size 28018723
qwen3_embedding_0.6b_axmodel/qwen3_p128_l22_together.axmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c77fb82587bbc3c682c28cae6bf3cc3ac290bd59ffd60b6f03c875970c3518b8
+ size 28018723
qwen3_embedding_0.6b_axmodel/qwen3_p128_l23_together.axmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9c1d6057741c34a9b871b4ae2852c3e9616b2423abd9b1035a26c003d0158645
+ size 28019491
qwen3_embedding_0.6b_axmodel/qwen3_p128_l24_together.axmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b14f1c52826b17c7eb54f6e2ab9c0a5c353cf34b0f90a252d2c3d13114d0a284
+ size 28020259
qwen3_embedding_0.6b_axmodel/qwen3_p128_l25_together.axmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3b81f7122c92a9b04a19581d4bbba8f8392037e79db1ae3ead467f2f6dcb9975
+ size 28019395
qwen3_embedding_0.6b_axmodel/qwen3_p128_l26_together.axmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9c5a70d619eebbd1b8a072902c807022c2cabd421018913fb47a44627132eb98
+ size 28018723
qwen3_embedding_0.6b_axmodel/qwen3_p128_l27_together.axmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:26b447aa2df812d621e272c51115d94d46c8f8cb9bcce9281d86d2f9c2612211
+ size 28020547
qwen3_embedding_0.6b_axmodel/qwen3_p128_l2_together.axmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:167e1b7111d86d7d0aba4879314fc59b93b2385242caddea37ae67c54b34cc4e
+ size 28019107
qwen3_embedding_0.6b_axmodel/qwen3_p128_l3_together.axmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5d929fe10751a50dae6fe2379149f912709b55f597c20686bcf2b4c812bb46d9
+ size 28018723
qwen3_embedding_0.6b_axmodel/qwen3_p128_l4_together.axmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:40dbedcb1f86e618d2f7cf68d44dfce85e120cdea9475ebb43f1d04852a536ed
+ size 28018723
qwen3_embedding_0.6b_axmodel/qwen3_p128_l5_together.axmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8ee25b58ea5b0ae8a14a031c6f8bd87c826cf5419d9be3c542b9032cfb18b929
+ size 28018723
qwen3_embedding_0.6b_axmodel/qwen3_p128_l6_together.axmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f20cdb2296614658bc35171dda33c9ef84d71771620955dd30fe0b700f88cf0a
+ size 28018723
qwen3_embedding_0.6b_axmodel/qwen3_p128_l7_together.axmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d8b81646a9906eef6eb69eda019cc8294db770dbaca915b7ea0e3a378f35f988
+ size 28019555
qwen3_embedding_0.6b_axmodel/qwen3_p128_l8_together.axmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ba4c0c875d63cafaff17f77378fa50a4ae4ad87e7936ffc3f15b06082d8bb68c
+ size 28018723
qwen3_embedding_0.6b_axmodel/qwen3_p128_l9_together.axmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9b1d57cb66e5e04ddce9a5c8f7b308cc7995fedb4f92e9038260939d97d151cf
+ size 28018723
qwen3_embedding_0.6b_axmodel/qwen3_post.axmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8b4557f9ba5f22fce6e1584d4b1119e0bbd4dc0f0d241233d1e8ee7f39775757
+ size 169711178
qwen3_embedding_0.6b_tokenizer/.gitattributes ADDED
@@ -0,0 +1,36 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
qwen3_embedding_0.6b_tokenizer/1_Pooling/config.json ADDED
@@ -0,0 +1,10 @@
+ {
+   "word_embedding_dimension": 1024,
+   "pooling_mode_cls_token": false,
+   "pooling_mode_mean_tokens": false,
+   "pooling_mode_max_tokens": false,
+   "pooling_mode_mean_sqrt_len_tokens": false,
+   "pooling_mode_weightedmean_tokens": false,
+   "pooling_mode_lasttoken": true,
+   "include_prompt": true
+ }
qwen3_embedding_0.6b_tokenizer/README.md ADDED
@@ -0,0 +1,292 @@
+ ---
+ license: apache-2.0
+ base_model:
+ - Qwen/Qwen3-0.6B-Base
+ tags:
+ - transformers
+ - sentence-transformers
+ - sentence-similarity
+ - feature-extraction
+ - text-embeddings-inference
+ ---
+ # Qwen3-Embedding-0.6B
+
+ <p align="center">
+     <img src="https://qianwen-res.oss-accelerate-overseas.aliyuncs.com/logo_qwen3.png" width="400"/>
+ </p>
+
+ ## Highlights
+
+ The Qwen3 Embedding model series is the latest proprietary model of the Qwen family, specifically designed for text embedding and ranking tasks. Building upon the dense foundational models of the Qwen3 series, it provides a comprehensive range of text embedding and reranking models in various sizes (0.6B, 4B, and 8B). This series inherits the exceptional multilingual capabilities, long-text understanding, and reasoning skills of its foundational model. The Qwen3 Embedding series represents significant advancements in multiple text embedding and ranking tasks, including text retrieval, code retrieval, text classification, text clustering, and bitext mining.
+
+ **Exceptional Versatility**: The embedding model has achieved state-of-the-art performance across a wide range of downstream application evaluations. The 8B size embedding model ranks **No.1** in the MTEB multilingual leaderboard (as of June 5, 2025, score **70.58**), while the reranking model excels in various text retrieval scenarios.
+
+ **Comprehensive Flexibility**: The Qwen3 Embedding series offers a full spectrum of sizes (from 0.6B to 8B) for both embedding and reranking models, catering to diverse use cases that prioritize efficiency and effectiveness. Developers can seamlessly combine these two modules. Additionally, the embedding model allows for flexible vector definitions across all dimensions, and both embedding and reranking models support user-defined instructions to enhance performance for specific tasks, languages, or scenarios.
+
+ **Multilingual Capability**: The Qwen3 Embedding series offers support for over 100 languages, thanks to the multilingual capabilities of the Qwen3 models. This includes various programming languages, and provides robust multilingual, cross-lingual, and code retrieval capabilities.
+
+ ## Model Overview
+
+ **Qwen3-Embedding-0.6B** has the following features:
+
+ - Model Type: Text Embedding
+ - Supported Languages: 100+ Languages
+ - Number of Parameters: 0.6B
+ - Context Length: 32k
+ - Embedding Dimension: Up to 1024, supports user-defined output dimensions ranging from 32 to 1024
+
+ For more details, including benchmark evaluation, hardware requirements, and inference performance, please refer to our [blog](https://qwenlm.github.io/blog/qwen3-embedding/) and [GitHub](https://github.com/QwenLM/Qwen3-Embedding).
+
+ ## Qwen3 Embedding Series Model list
+
+ | Model Type | Models | Size | Layers | Sequence Length | Embedding Dimension | MRL Support | Instruction Aware |
+ |------------------|----------------------|------|--------|-----------------|---------------------|-------------|----------------|
+ | Text Embedding | [Qwen3-Embedding-0.6B](https://huggingface.co/Qwen/Qwen3-Embedding-0.6B) | 0.6B | 28 | 32K | 1024 | Yes | Yes |
+ | Text Embedding | [Qwen3-Embedding-4B](https://huggingface.co/Qwen/Qwen3-Embedding-4B) | 4B | 36 | 32K | 2560 | Yes | Yes |
+ | Text Embedding | [Qwen3-Embedding-8B](https://huggingface.co/Qwen/Qwen3-Embedding-8B) | 8B | 36 | 32K | 4096 | Yes | Yes |
+ | Text Reranking | [Qwen3-Reranker-0.6B](https://huggingface.co/Qwen/Qwen3-Reranker-0.6B) | 0.6B | 28 | 32K | - | - | Yes |
+ | Text Reranking | [Qwen3-Reranker-4B](https://huggingface.co/Qwen/Qwen3-Reranker-4B) | 4B | 36 | 32K | - | - | Yes |
+ | Text Reranking | [Qwen3-Reranker-8B](https://huggingface.co/Qwen/Qwen3-Reranker-8B) | 8B | 36 | 32K | - | - | Yes |
+
+ > **Note**:
+ > - `MRL Support` indicates whether the embedding model supports custom dimensions for the final embedding.
+ > - `Instruction Aware` notes whether the embedding or reranking model supports customizing the input instruction according to different tasks.
+ > - Our evaluation indicates that, for most downstream tasks, using instructions (instruct) typically yields an improvement of 1% to 5% compared to not using them. Therefore, we recommend that developers create tailored instructions specific to their tasks and scenarios. In multilingual contexts, we also advise users to write their instructions in English, as most instructions utilized during the model training process were originally written in English.
+
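+ Since the embedding models support MRL, a smaller embedding can be obtained by keeping only the leading dimensions and re-normalizing. A minimal sketch (assuming `embeddings` is a `(batch, 1024)` tensor produced as in the usage examples below):
+
+ ```python
+ import torch.nn.functional as F
+
+ dim = 256  # any user-defined size from 32 to 1024
+ short_embeddings = F.normalize(embeddings[:, :dim], p=2, dim=1)
+ ```
+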
+ ## Usage
+
+ With Transformers versions earlier than 4.51.0, you may encounter the following error:
+ ```
+ KeyError: 'qwen3'
+ ```
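+
+ Upgrading transformers resolves this:
+
+ ```bash
+ pip install -U "transformers>=4.51.0"
+ ```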
+
+ ### Sentence Transformers Usage
+
+ ```python
+ # Requires transformers>=4.51.0
+ # Requires sentence-transformers>=2.7.0
+
+ from sentence_transformers import SentenceTransformer
+
+ # Load the model
+ model = SentenceTransformer("Qwen/Qwen3-Embedding-0.6B")
+
+ # We recommend enabling flash_attention_2 for better acceleration and memory saving,
+ # together with setting `padding_side` to "left":
+ # model = SentenceTransformer(
+ #     "Qwen/Qwen3-Embedding-0.6B",
+ #     model_kwargs={"attn_implementation": "flash_attention_2", "device_map": "auto"},
+ #     tokenizer_kwargs={"padding_side": "left"},
+ # )
+
+ # The queries and documents to embed
+ queries = [
+     "What is the capital of China?",
+     "Explain gravity",
+ ]
+ documents = [
+     "The capital of China is Beijing.",
+     "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
+ ]
+
+ # Encode the queries and documents. Note that queries benefit from using a prompt
+ # Here we use the prompt called "query" stored under `model.prompts`, but you can
+ # also pass your own prompt via the `prompt` argument
+ query_embeddings = model.encode(queries, prompt_name="query")
+ document_embeddings = model.encode(documents)
+
+ # Compute the (cosine) similarity between the query and document embeddings
+ similarity = model.similarity(query_embeddings, document_embeddings)
+ print(similarity)
+ # tensor([[0.7646, 0.1414],
+ #         [0.1355, 0.6000]])
+ ```
+
+ ### Transformers Usage
+
+ ```python
+ # Requires transformers>=4.51.0
+
+ import torch
+ import torch.nn.functional as F
+
+ from torch import Tensor
+ from transformers import AutoTokenizer, AutoModel
+
+
+ def last_token_pool(last_hidden_states: Tensor,
+                     attention_mask: Tensor) -> Tensor:
+     left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
+     if left_padding:
+         return last_hidden_states[:, -1]
+     else:
+         sequence_lengths = attention_mask.sum(dim=1) - 1
+         batch_size = last_hidden_states.shape[0]
+         return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
+
+
+ def get_detailed_instruct(task_description: str, query: str) -> str:
+     return f'Instruct: {task_description}\nQuery:{query}'
+
+ # Each query must come with a one-sentence instruction that describes the task
+ task = 'Given a web search query, retrieve relevant passages that answer the query'
+
+ queries = [
+     get_detailed_instruct(task, 'What is the capital of China?'),
+     get_detailed_instruct(task, 'Explain gravity')
+ ]
+ # No need to add instruction for retrieval documents
+ documents = [
+     "The capital of China is Beijing.",
+     "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun."
+ ]
+ input_texts = queries + documents
+
+ tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-Embedding-0.6B', padding_side='left')
+ model = AutoModel.from_pretrained('Qwen/Qwen3-Embedding-0.6B')
+
+ # We recommend enabling flash_attention_2 for better acceleration and memory saving.
+ # model = AutoModel.from_pretrained('Qwen/Qwen3-Embedding-0.6B', attn_implementation="flash_attention_2", torch_dtype=torch.float16).cuda()
+
+ max_length = 8192
+
+ # Tokenize the input texts
+ batch_dict = tokenizer(
+     input_texts,
+     padding=True,
+     truncation=True,
+     max_length=max_length,
+     return_tensors="pt",
+ )
+ batch_dict.to(model.device)
+ outputs = model(**batch_dict)
+ embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
+
+ # normalize embeddings
+ embeddings = F.normalize(embeddings, p=2, dim=1)
+ scores = (embeddings[:2] @ embeddings[2:].T)
+ print(scores.tolist())
+ # [[0.7645568251609802, 0.14142508804798126], [0.13549736142158508, 0.5999549627304077]]
+ ```
+
+ ### vLLM Usage
+
+ ```python
+ # Requires vllm>=0.8.5
+ import torch
+ import vllm
+ from vllm import LLM
+
+ def get_detailed_instruct(task_description: str, query: str) -> str:
+     return f'Instruct: {task_description}\nQuery:{query}'
+
+ # Each query must come with a one-sentence instruction that describes the task
+ task = 'Given a web search query, retrieve relevant passages that answer the query'
+
+ queries = [
+     get_detailed_instruct(task, 'What is the capital of China?'),
+     get_detailed_instruct(task, 'Explain gravity')
+ ]
+ # No need to add instruction for retrieval documents
+ documents = [
+     "The capital of China is Beijing.",
+     "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun."
+ ]
+ input_texts = queries + documents
+
+ model = LLM(model="Qwen/Qwen3-Embedding-0.6B", task="embed")
+
+ outputs = model.embed(input_texts)
+ embeddings = torch.tensor([o.outputs.embedding for o in outputs])
+ scores = (embeddings[:2] @ embeddings[2:].T)
+ print(scores.tolist())
+ # [[0.7620252966880798, 0.14078938961029053], [0.1358368694782257, 0.6013815999031067]]
+ ```
+
+ 📌 **Tip**: We recommend that developers customize the `instruct` according to their specific scenarios, tasks, and languages. Our tests have shown that in most retrieval scenarios, not using an `instruct` on the query side can lead to a drop in retrieval performance by approximately 1% to 5%.
+
+ ### Text Embeddings Inference (TEI) Usage
+
+ You can either run / deploy TEI on NVIDIA GPUs as:
+
+ ```bash
+ docker run --gpus all -p 8080:80 -v hf_cache:/data --pull always ghcr.io/huggingface/text-embeddings-inference:1.7.2 --model-id Qwen/Qwen3-Embedding-0.6B --dtype float16
+ ```
+
+ Or on CPU devices as:
+
+ ```bash
+ docker run -p 8080:80 -v hf_cache:/data --pull always ghcr.io/huggingface/text-embeddings-inference:cpu-1.7.2 --model-id Qwen/Qwen3-Embedding-0.6B
+ ```
+
+ Then generate the embeddings by sending an HTTP POST request:
+
+ ```bash
+ curl http://localhost:8080/embed \
+     -X POST \
+     -d '{"inputs": ["Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery: What is the capital of China?", "Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery: Explain gravity"]}' \
+     -H "Content-Type: application/json"
+ ```
+
+ ## Evaluation
+
+ ### MTEB (Multilingual)
+
+ | Model | Size | Mean (Task) | Mean (Type) | Bitext Mining | Class. | Clust. | Inst. Retri. | Multi. Class. | Pair. Class. | Rerank | Retri. | STS |
+ |----------------------------------|:-------:|:-------------:|:-------------:|:--------------:|:--------:|:--------:|:--------------:|:---------------:|:--------------:|:--------:|:--------:|:------:|
+ | NV-Embed-v2 | 7B | 56.29 | 49.58 | 57.84 | 57.29 | 40.80 | 1.04 | 18.63 | 78.94 | 63.82 | 56.72 | 71.10|
+ | GritLM-7B | 7B | 60.92 | 53.74 | 70.53 | 61.83 | 49.75 | 3.45 | 22.77 | 79.94 | 63.78 | 58.31 | 73.33|
+ | BGE-M3 | 0.6B | 59.56 | 52.18 | 79.11 | 60.35 | 40.88 | -3.11 | 20.1 | 80.76 | 62.79 | 54.60 | 74.12|
+ | multilingual-e5-large-instruct | 0.6B | 63.22 | 55.08 | 80.13 | 64.94 | 50.75 | -0.40 | 22.91 | 80.86 | 62.61 | 57.12 | 76.81|
+ | gte-Qwen2-1.5B-instruct | 1.5B | 59.45 | 52.69 | 62.51 | 58.32 | 52.05 | 0.74 | 24.02 | 81.58 | 62.58 | 60.78 | 71.61|
+ | gte-Qwen2-7b-Instruct | 7B | 62.51 | 55.93 | 73.92 | 61.55 | 52.77 | 4.94 | 25.48 | 85.13 | 65.55 | 60.08 | 73.98|
+ | text-embedding-3-large | - | 58.93 | 51.41 | 62.17 | 60.27 | 46.89 | -2.68 | 22.03 | 79.17 | 63.89 | 59.27 | 71.68|
+ | Cohere-embed-multilingual-v3.0 | - | 61.12 | 53.23 | 70.50 | 62.95 | 46.89 | -1.89 | 22.74 | 79.88 | 64.07 | 59.16 | 74.80|
+ | Gemini Embedding | - | 68.37 | 59.59 | 79.28 | 71.82 | 54.59 | 5.18 | **29.16** | 83.63 | 65.58 | 67.71 | 79.40|
+ | **Qwen3-Embedding-0.6B** | 0.6B | 64.33 | 56.00 | 72.22 | 66.83 | 52.33 | 5.09 | 24.59 | 80.83 | 61.41 | 64.64 | 76.17|
+ | **Qwen3-Embedding-4B** | 4B | 69.45 | 60.86 | 79.36 | 72.33 | 57.15 | **11.56** | 26.77 | 85.05 | 65.08 | 69.60 | 80.86|
+ | **Qwen3-Embedding-8B** | 8B | **70.58** | **61.69** | **80.89** | **74.00** | **57.65** | 10.06 | 28.66 | **86.40** | **65.63** | **70.88** | **81.08** |
+
+ > **Note**: For compared models, the scores are retrieved from MTEB online [leaderboard](https://huggingface.co/spaces/mteb/leaderboard) on May 24th, 2025.
+
+ ### MTEB (Eng v2)
+
+ | MTEB English / Models | Param. | Mean(Task) | Mean(Type) | Class. | Clust. | Pair Class. | Rerank. | Retri. | STS | Summ. |
+ |--------------------------------|:--------:|:------------:|:------------:|:--------:|:--------:|:-------------:|:---------:|:--------:|:-------:|:-------:|
+ | multilingual-e5-large-instruct | 0.6B | 65.53 | 61.21 | 75.54 | 49.89 | 86.24 | 48.74 | 53.47 | 84.72 | 29.89 |
+ | NV-Embed-v2 | 7.8B | 69.81 | 65.00 | 87.19 | 47.66 | 88.69 | 49.61 | 62.84 | 83.82 | 35.21 |
+ | GritLM-7B | 7.2B | 67.07 | 63.22 | 81.25 | 50.82 | 87.29 | 49.59 | 54.95 | 83.03 | 35.65 |
+ | gte-Qwen2-1.5B-instruct | 1.5B | 67.20 | 63.26 | 85.84 | 53.54 | 87.52 | 49.25 | 50.25 | 82.51 | 33.94 |
+ | stella_en_1.5B_v5 | 1.5B | 69.43 | 65.32 | 89.38 | 57.06 | 88.02 | 50.19 | 52.42 | 83.27 | 36.91 |
+ | gte-Qwen2-7B-instruct | 7.6B | 70.72 | 65.77 | 88.52 | 58.97 | 85.9 | 50.47 | 58.09 | 82.69 | 35.74 |
+ | gemini-embedding-exp-03-07 | - | 73.3 | 67.67 | 90.05 | 59.39 | 87.7 | 48.59 | 64.35 | 85.29 | 38.28 |
+ | **Qwen3-Embedding-0.6B** | 0.6B | 70.70 | 64.88 | 85.76 | 54.05 | 84.37 | 48.18 | 61.83 | 86.57 | 33.43 |
+ | **Qwen3-Embedding-4B** | 4B | 74.60 | 68.10 | 89.84 | 57.51 | 87.01 | 50.76 | 68.46 | 88.72 | 34.39 |
+ | **Qwen3-Embedding-8B** | 8B | 75.22 | 68.71 | 90.43 | 58.57 | 87.52 | 51.56 | 69.44 | 88.58 | 34.83 |
+
+ ### C-MTEB (MTEB Chinese)
+
+ | C-MTEB | Param. | Mean(Task) | Mean(Type) | Class. | Clust. | Pair Class. | Rerank. | Retr. | STS |
+ |------------------|--------|------------|------------|--------|--------|-------------|---------|-------|-------|
+ | multilingual-e5-large-instruct | 0.6B | 58.08 | 58.24 | 69.80 | 48.23 | 64.52 | 57.45 | 63.65 | 45.81 |
+ | bge-multilingual-gemma2 | 9B | 67.64 | - | 75.31 | 59.30 | 86.67 | 68.28 | 73.73 | 55.19 |
+ | gte-Qwen2-1.5B-instruct | 1.5B | 67.12 | 67.79 | 72.53 | 54.61 | 79.5 | 68.21 | 71.86 | 60.05 |
+ | gte-Qwen2-7B-instruct | 7.6B | 71.62 | 72.19 | 75.77 | 66.06 | 81.16 | 69.24 | 75.70 | 65.20 |
+ | ritrieve_zh_v1 | 0.3B | 72.71 | 73.85 | 76.88 | 66.5 | 85.98 | 72.86 | 76.97 | 63.92 |
+ | **Qwen3-Embedding-0.6B** | 0.6B | 66.33 | 67.45 | 71.40 | 68.74 | 76.42 | 62.58 | 71.03 | 54.52 |
+ | **Qwen3-Embedding-4B** | 4B | 72.27 | 73.51 | 75.46 | 77.89 | 83.34 | 66.05 | 77.03 | 61.26 |
+ | **Qwen3-Embedding-8B** | 8B | 73.84 | 75.00 | 76.97 | 80.08 | 84.23 | 66.99 | 78.21 | 63.53 |
+
+
+ ## Citation
+
+ If you find our work helpful, feel free to cite us.
+
+ ```
+ @article{qwen3embedding,
+   title={Qwen3 Embedding: Advancing Text Embedding and Reranking Through Foundation Models},
+   author={Zhang, Yanzhao and Li, Mingxin and Long, Dingkun and Zhang, Xin and Lin, Huan and Yang, Baosong and Xie, Pengjun and Yang, An and Liu, Dayiheng and Lin, Junyang and Huang, Fei and Zhou, Jingren},
+   journal={arXiv preprint arXiv:2506.05176},
+   year={2025}
+ }
+ ```
qwen3_embedding_0.6b_tokenizer/config.json ADDED
@@ -0,0 +1,30 @@
+ {
+   "architectures": [
+     "Qwen3ForCausalLM"
+   ],
+   "attention_bias": false,
+   "attention_dropout": 0.0,
+   "bos_token_id": 151643,
+   "eos_token_id": 151643,
+   "head_dim": 128,
+   "hidden_act": "silu",
+   "hidden_size": 1024,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "max_position_embeddings": 32768,
+   "max_window_layers": 28,
+   "model_type": "qwen3",
+   "num_attention_heads": 16,
+   "num_hidden_layers": 28,
+   "num_key_value_heads": 8,
+   "rms_norm_eps": 1e-06,
+   "rope_scaling": null,
+   "rope_theta": 1000000,
+   "sliding_window": null,
+   "tie_word_embeddings": true,
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.51.3",
+   "use_cache": true,
+   "use_sliding_window": false,
+   "vocab_size": 151669
+ }
qwen3_embedding_0.6b_tokenizer/config_sentence_transformers.json ADDED
@@ -0,0 +1,8 @@
+ {
+   "prompts": {
+     "query": "Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery:",
+     "document": ""
+   },
+   "default_prompt_name": null,
+   "similarity_fn_name": "cosine"
+ }
qwen3_embedding_0.6b_tokenizer/generation_config.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "bos_token_id": 151643,
+   "eos_token_id": 151643,
+   "max_new_tokens": 2048,
+   "transformers_version": "4.51.3"
+ }
qwen3_embedding_0.6b_tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
qwen3_embedding_0.6b_tokenizer/modules.json ADDED
@@ -0,0 +1,20 @@
+ [
+   {
+     "idx": 0,
+     "name": "0",
+     "path": "",
+     "type": "sentence_transformers.models.Transformer"
+   },
+   {
+     "idx": 1,
+     "name": "1",
+     "path": "1_Pooling",
+     "type": "sentence_transformers.models.Pooling"
+   },
+   {
+     "idx": 2,
+     "name": "2",
+     "path": "2_Normalize",
+     "type": "sentence_transformers.models.Normalize"
+   }
+ ]
qwen3_embedding_0.6b_tokenizer/tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:def76fb086971c7867b829c23a26261e38d9d74e02139253b38aeb9df8b4b50a
+ size 11423705
qwen3_embedding_0.6b_tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,240 @@
+ {
+   "add_bos_token": false,
+   "add_prefix_space": false,
+   "added_tokens_decoder": {
+     "151643": {
+       "content": "<|endoftext|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151644": {
+       "content": "<|im_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151645": {
+       "content": "<|im_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151646": {
+       "content": "<|object_ref_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151647": {
+       "content": "<|object_ref_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151648": {
+       "content": "<|box_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151649": {
+       "content": "<|box_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151650": {
+       "content": "<|quad_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151651": {
+       "content": "<|quad_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151652": {
+       "content": "<|vision_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151653": {
+       "content": "<|vision_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151654": {
+       "content": "<|vision_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151655": {
+       "content": "<|image_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151656": {
+       "content": "<|video_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151657": {
+       "content": "<tool_call>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151658": {
+       "content": "</tool_call>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151659": {
+       "content": "<|fim_prefix|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151660": {
+       "content": "<|fim_middle|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151661": {
+       "content": "<|fim_suffix|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151662": {
+       "content": "<|fim_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151663": {
+       "content": "<|repo_name|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151664": {
+       "content": "<|file_sep|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151665": {
+       "content": "<tool_response>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151666": {
+       "content": "</tool_response>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151667": {
+       "content": "<think>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151668": {
+       "content": "</think>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     }
+   },
+   "additional_special_tokens": [
+     "<|im_start|>",
+     "<|im_end|>",
+     "<|object_ref_start|>",
+     "<|object_ref_end|>",
+     "<|box_start|>",
+     "<|box_end|>",
+     "<|quad_start|>",
+     "<|quad_end|>",
+     "<|vision_start|>",
+     "<|vision_end|>",
+     "<|vision_pad|>",
+     "<|image_pad|>",
+     "<|video_pad|>"
+   ],
+   "bos_token": null,
+   "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set content = message.content %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is defined and message.reasoning_content is not none %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '</think>' in message.content %}\n {%- set content = message.content.split('</think>')[-1].lstrip('\\n') %}\n {%- set reasoning_content = message.content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '<think>\\n\\n</think>\\n\\n' }}\n {%- endif %}\n{%- endif %}",
+   "clean_up_tokenization_spaces": false,
+   "eos_token": "<|im_end|>",
+   "errors": "replace",
+   "extra_special_tokens": {},
+   "model_max_length": 131072,
+   "pad_token": "<|endoftext|>",
+   "split_special_tokens": false,
+   "tokenizer_class": "Qwen2Tokenizer",
+   "unk_token": null
+ }
qwen3_embedding_0.6b_tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
utils/infer_func.py ADDED
@@ -0,0 +1,217 @@
+ import numpy as np
+ from tqdm import tqdm
+ from axengine import InferenceSession
+ from ml_dtypes import bfloat16
+
+
+ class InferManager:
+     def __init__(self, config, model_dir):
+
+         self.config = config
+         self.max_seq_len = 2559
+
+         # Per-head dim (head_dim if set, else hidden_size / num_attention_heads);
+         # with GQA the KV-cache width is head_dim * num_key_value_heads.
+         self.sub_dim = config.hidden_size // config.num_attention_heads if not config.head_dim else config.head_dim
+         self.kv_dim = self.sub_dim * config.num_key_value_heads
+
+         self.k_caches = [
+             np.zeros((1, self.max_seq_len, self.kv_dim), dtype=bfloat16)
+             for _ in range(config.num_hidden_layers)
+         ]
+         self.v_caches = [
+             np.zeros((1, self.max_seq_len, self.kv_dim), dtype=bfloat16)
+             for _ in range(config.num_hidden_layers)
+         ]
+
+         # One axmodel per decoder layer, plus a post-processing graph.
+         self.decoder_sessions = []
+         for layer_idx in tqdm(range(config.num_hidden_layers), desc="Init InferenceSession"):
+             session = InferenceSession(
+                 f"{model_dir}/qwen3_p128_l{layer_idx}_together.axmodel"
+             )
+             self.decoder_sessions.append(session)
+         self.post_process_session = InferenceSession(
+             f"{model_dir}/qwen3_post.axmodel"
+         )
+         print("Model loaded successfully!")
+
+     @staticmethod
+     def _top_p(probs: np.ndarray, p: float) -> np.ndarray:
+         # Zero out the tail of the distribution once cumulative mass reaches p.
+         sorted_indices = np.argsort(probs)
+         filtered = probs.copy()
+         cumulative = 0
+         for idx in sorted_indices[::-1]:
+             if cumulative >= p:
+                 filtered[idx] = 0
+             cumulative += filtered[idx]
+         return filtered / cumulative
+
+     @staticmethod
+     def _softmax(logits: np.ndarray) -> np.ndarray:
+         logits = logits - logits.max()
+         exp_logits = np.exp(logits)
+         return (exp_logits / np.sum(exp_logits)).astype(np.float64)
+
+     def post_process(self, logits, top_k=1, top_p=0.9, temperature=0.6):
+         # Top-k / top-p / temperature sampling over the output logits.
+         logits = logits.astype(np.float32).flatten()
+         candidate_indices = np.argpartition(logits, -top_k)[-top_k:]
+         candidate_logits = logits[candidate_indices] / temperature
+         candidate_probs = self._softmax(candidate_logits)
+         candidate_probs = self._top_p(candidate_probs, top_p)
+         candidate_probs = candidate_probs.astype(np.float64) / candidate_probs.sum()
+         chosen_idx = np.random.multinomial(1, candidate_probs).argmax()
+         next_token = candidate_indices[chosen_idx]
+         return next_token, candidate_indices, candidate_probs
+
+     def gen_slice_indices(self, token_len, prefill=128, expand=128):
+         remaining = max(0, token_len - prefill)
+         extra_blocks = (remaining + expand - 1) // expand
+         return list(range(extra_blocks + 1))
+
+     def prefill(
+         self,
+         tokenizer,
+         token_ids,
+         embed_data,
+         slice_len=128,
+     ):
+         """
+         Prefill step for chunked inference: the prompt is processed in
+         slices of `slice_len` tokens, filling the KV caches layer by layer.
+         """
+         seq_len = len(token_ids)
+         slice_indices = [i for i in range(seq_len // slice_len + 1)]
+         print(f"slice_indices: {slice_indices}")
+         total_prefill_len = slice_len * (slice_indices[-1] + 1)
+
+         if total_prefill_len > 0:
+             for slice_idx in slice_indices:
+                 indices = np.arange(
+                     slice_idx * slice_len,
+                     (slice_idx + 1) * slice_len,
+                     dtype=np.uint32
+                 ).reshape((1, slice_len))
+
+                 # Causal mask for this slice: each row may attend to all earlier
+                 # positions (previous slices plus up to itself in this slice).
+                 mask = (
+                     np.zeros((1, slice_len, slice_len * (slice_idx + 1)))
+                     - 65536
+                 )
+                 data = np.zeros((1, slice_len, self.config.hidden_size)).astype(bfloat16)
+                 for i, t in enumerate(
+                     range(
+                         slice_idx * slice_len,
+                         (slice_idx + 1) * slice_len,
+                     )
+                 ):
+                     if t < len(token_ids):
+                         mask[:, i, : slice_idx * slice_len + i + 1] = 0
+                         data[:, i : i + 1, :] = (
+                             embed_data[t]
+                             .reshape((1, 1, self.config.hidden_size))
+                             .astype(bfloat16)
+                         )
+
+                 remain_len = (
+                     seq_len - slice_idx * slice_len
+                     if slice_idx == slice_indices[-1]
+                     else slice_len
+                 )
+                 mask = mask.astype(bfloat16)
+                 for layer_idx in range(self.config.num_hidden_layers):
+                     input_feed = {
+                         "K_cache": (
+                             self.k_caches[layer_idx][:, 0 : slice_len * slice_idx, :]
+                             if slice_idx
+                             else np.zeros((1, 1, self.config.hidden_size), dtype=bfloat16)
+                         ),
+                         "V_cache": (
+                             self.v_caches[layer_idx][:, 0 : slice_len * slice_idx, :]
+                             if slice_idx
+                             else np.zeros((1, 1, self.config.hidden_size), dtype=bfloat16)
+                         ),
+                         "indices": indices,
+                         "input": data,
+                         "mask": mask,
+                     }
+                     outputs = self.decoder_sessions[layer_idx].run(None, input_feed, shape_group=slice_idx + 1)
+                     self.k_caches[layer_idx][
+                         :,
+                         slice_idx * slice_len : slice_idx * slice_len + remain_len,
+                         :,
+                     ] = outputs[0][:, :remain_len, :]
+                     self.v_caches[layer_idx][
+                         :,
+                         slice_idx * slice_len : slice_idx * slice_len + remain_len,
+                         :,
+                     ] = outputs[1][:, :remain_len, :]
+                     data = outputs[2]
+
+                 print("Slice prefill done:", slice_idx)
+
+         # For the embedding use case the per-token hidden states of the final
+         # slice are returned directly; no lm_head or sampling is involved.
+         return data[:, :remain_len, :]
+
+     def decode(
+         self,
+         tokenizer,
+         token_ids,
+         embed_matrix,
+         prefill_len=128,
+         slice_len=128
+     ):
+         print("answer >>", tokenizer.decode(token_ids[-1], skip_special_tokens=True), end='', flush=True)
+         self.max_seq_len = 2559
+         mask = np.zeros((1, 1, self.max_seq_len + 1), dtype=np.float32).astype(bfloat16)
+         mask[:, :, :self.max_seq_len] -= 65536
+         seq_len = len(token_ids) - 1
+         if prefill_len > 0:
+             mask[:, :, :seq_len] = 0
+         for step_idx in range(self.max_seq_len):
+             if prefill_len > 0 and step_idx < seq_len:
+                 continue
+             cur_token = token_ids[step_idx]
+             indices = np.array([step_idx], np.uint32).reshape((1, 1))
+             data = embed_matrix[cur_token, :].reshape((1, 1, self.config.hidden_size)).astype(bfloat16)
+             for layer_idx in range(self.config.num_hidden_layers):
+                 input_feed = {
+                     "K_cache": self.k_caches[layer_idx],
+                     "V_cache": self.v_caches[layer_idx],
+                     "indices": indices,
+                     "input": data,
+                     "mask": mask,
+                 }
+                 outputs = self.decoder_sessions[layer_idx].run(None, input_feed, shape_group=0)
+                 self.k_caches[layer_idx][:, step_idx, :] = outputs[0][:, :, :]
+                 self.v_caches[layer_idx][:, step_idx, :] = outputs[1][:, :, :]
+                 data = outputs[2]
+             mask[..., step_idx] = 0
+             if step_idx < seq_len - 1:
+                 continue
+             else:
+                 post_out = self.post_process_session.run(None, {"input": data})[0]
+                 next_token, possible_tokens, possible_probs = self.post_process(post_out)
+                 token_ids.append(next_token)
+                 if next_token == tokenizer.eos_token_id and next_token > seq_len:
+                     break
+                 print(tokenizer.decode(next_token, skip_special_tokens=True), end='', flush=True)