yongqiang
commited on
Commit
·
04527ce
0
Parent(s):
Initialize this repo
Browse files- .gitattributes +37 -0
- README.md +90 -0
- config.json +0 -0
- gemma3_axmodel/gemma3_text_p128_l0_together.axmodel +3 -0
- gemma3_axmodel/gemma3_text_p128_l10_together.axmodel +3 -0
- gemma3_axmodel/gemma3_text_p128_l11_together.axmodel +3 -0
- gemma3_axmodel/gemma3_text_p128_l12_together.axmodel +3 -0
- gemma3_axmodel/gemma3_text_p128_l13_together.axmodel +3 -0
- gemma3_axmodel/gemma3_text_p128_l14_together.axmodel +3 -0
- gemma3_axmodel/gemma3_text_p128_l15_together.axmodel +3 -0
- gemma3_axmodel/gemma3_text_p128_l16_together.axmodel +3 -0
- gemma3_axmodel/gemma3_text_p128_l17_together.axmodel +3 -0
- gemma3_axmodel/gemma3_text_p128_l18_together.axmodel +3 -0
- gemma3_axmodel/gemma3_text_p128_l19_together.axmodel +3 -0
- gemma3_axmodel/gemma3_text_p128_l1_together.axmodel +3 -0
- gemma3_axmodel/gemma3_text_p128_l20_together.axmodel +3 -0
- gemma3_axmodel/gemma3_text_p128_l21_together.axmodel +3 -0
- gemma3_axmodel/gemma3_text_p128_l22_together.axmodel +3 -0
- gemma3_axmodel/gemma3_text_p128_l23_together.axmodel +3 -0
- gemma3_axmodel/gemma3_text_p128_l24_together.axmodel +3 -0
- gemma3_axmodel/gemma3_text_p128_l25_together.axmodel +3 -0
- gemma3_axmodel/gemma3_text_p128_l2_together.axmodel +3 -0
- gemma3_axmodel/gemma3_text_p128_l3_together.axmodel +3 -0
- gemma3_axmodel/gemma3_text_p128_l4_together.axmodel +3 -0
- gemma3_axmodel/gemma3_text_p128_l5_together.axmodel +3 -0
- gemma3_axmodel/gemma3_text_p128_l6_together.axmodel +3 -0
- gemma3_axmodel/gemma3_text_p128_l7_together.axmodel +3 -0
- gemma3_axmodel/gemma3_text_p128_l8_together.axmodel +3 -0
- gemma3_axmodel/gemma3_text_p128_l9_together.axmodel +3 -0
- gemma3_axmodel/gemma3_text_post.axmodel +3 -0
- gemma3_axmodel/model.embed_tokens.weight.bfloat16.bin +3 -0
- gemma3_axmodel/model.embed_tokens.weight.npy +3 -0
- gemma3_tokenizer/.gitattributes +38 -0
- gemma3_tokenizer/README.md +514 -0
- gemma3_tokenizer/added_tokens.json +3 -0
- gemma3_tokenizer/config.json +37 -0
- gemma3_tokenizer/generation_config.json +13 -0
- gemma3_tokenizer/model.safetensors +3 -0
- gemma3_tokenizer/special_tokens_map.json +33 -0
- gemma3_tokenizer/tokenizer.json +3 -0
- gemma3_tokenizer/tokenizer.model +3 -0
- gemma3_tokenizer/tokenizer_config.json +0 -0
- infer_axmodel.py +78 -0
- utils/infer_func.py +271 -0
.gitattributes
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.axmodel filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: bsd-3-clause
|
| 3 |
+
language:
|
| 4 |
+
- en
|
| 5 |
+
- zh
|
| 6 |
+
base_model:
|
| 7 |
+
- HuggingFaceTB/Gemma-3-1B-it
|
| 8 |
+
pipeline_tag: Text Generation
|
| 9 |
+
tags:
|
| 10 |
+
- HuggingFaceTB
|
| 11 |
+
- Gemma-3-1B-it
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
# Gemma-3-1B-it-Int8
|
| 15 |
+
|
| 16 |
+
This version of Gemma-3-1B-it has been converted to run on the Axera NPU using **w8a16** quantization.
|
| 17 |
+
|
| 18 |
+
Compatible with Pulsar2 version: 4.0
|
| 19 |
+
|
| 20 |
+
## Convert tools links:
|
| 21 |
+
|
| 22 |
+
For those who are interested in model conversion, you can try to export axmodel through the original repo:
|
| 23 |
+
- https://huggingface.co/HuggingFaceTB/Gemma-3-1B-it
|
| 24 |
+
|
| 25 |
+
- [Github for gemma-3-1B-it.axera](https://github.com/AXERA-TECH/gemma-3-1B-it.axera)
|
| 26 |
+
|
| 27 |
+
- [Pulsar2 Link, How to Convert LLM from Huggingface to axmodel](https://pulsar2-docs.readthedocs.io/en/latest/appendix/build_llm.html)
|
| 28 |
+
|
| 29 |
+
## Support Platform
|
| 30 |
+
- AX650
|
| 31 |
+
- [M4N-Dock(爱芯派Pro)](https://wiki.sipeed.com/hardware/zh/maixIV/m4ndock/m4ndock.html)
|
| 32 |
+
|
| 33 |
+
<!-- ## TODO Model infer time -->
|
| 34 |
+
|
| 35 |
+
## How to use
|
| 36 |
+
|
| 37 |
+
Download all files from this repository to the device.
|
| 38 |
+
|
| 39 |
+
**Using AX650 Board**
|
| 40 |
+
|
| 41 |
+
```bash
|
| 42 |
+
ai@ai-bj ~/yongqiang/Gemma-3-1B-it $ tree -L 1
|
| 43 |
+
.
|
| 44 |
+
├── config.json
|
| 45 |
+
├── gemma3_axmodel
|
| 46 |
+
├── gemma3_tokenizer
|
| 47 |
+
├── infer_axmodel.py
|
| 48 |
+
├── README.md
|
| 49 |
+
└── utils
|
| 50 |
+
|
| 51 |
+
3 directories, 3 files
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
#### Inference with AX650 Host, such as M4N-Dock(爱芯派Pro) or AX650N DEMO Board
|
| 55 |
+
|
| 56 |
+
**Text Generation**
|
| 57 |
+
|
| 58 |
+
input text:
|
| 59 |
+
|
| 60 |
+
```sh
|
| 61 |
+
$ python3 infer_axmodel.py -q "请用中文介绍一下你自己."
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
log information:
|
| 65 |
+
|
| 66 |
+
```bash
|
| 67 |
+
[INFO] Compiler version: 5.0-patch1-dirty 93949955-dirty
|
| 68 |
+
Init InferenceSession: 100%|██████████████████████████████████████████████████████████| 26/26 [00:21<00:00, 1.18it/s]
|
| 69 |
+
[INFO] Using provider: AxEngineExecutionProvider
|
| 70 |
+
[INFO] Model type: 2 (triple core)
|
| 71 |
+
[INFO] Compiler version: 5.0-patch1-dirty 93949955-dirty
|
| 72 |
+
Model loaded successfully!
|
| 73 |
+
slice_indices: [0]
|
| 74 |
+
Slice prefill done: 0
|
| 75 |
+
answer >> 您好!我是一个大型语言模型,由 Google 训练。
|
| 76 |
+
|
| 77 |
+
简单来说,我可以帮你做很多事情,比如:
|
| 78 |
+
|
| 79 |
+
* **回答你的问题:** 无论你问什么,我都会尽力用清晰、准确的语言来回答。
|
| 80 |
+
* **生成文本:** 比如写诗歌、故事、邮件、代码等等。
|
| 81 |
+
* **翻译语言:** 我可以将一种语言翻译成另一种语言。
|
| 82 |
+
* **总结文本:** 我可以帮你快速阅读一段文字,提取关键信息。
|
| 83 |
+
* **进行创意写作:** 我们可以一起头脑风暴,一起创作故事或文章。
|
| 84 |
+
|
| 85 |
+
我还在不断学习和进步,所以我的能力也在不断提升。
|
| 86 |
+
|
| 87 |
+
我是一个工具,可以帮助你,但不能代替人类的思考和判断。
|
| 88 |
+
|
| 89 |
+
希望我能帮到你^@! 你有什么想问的或者想让我做什么吗? 😊
|
| 90 |
+
```
|
config.json
ADDED
|
File without changes
|
gemma3_axmodel/gemma3_text_p128_l0_together.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:170750a73f52c06144d764196784d731f21d1186698bc8e5d11aa927c2a80a1c
|
| 3 |
+
size 49204919
|
gemma3_axmodel/gemma3_text_p128_l10_together.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5493aa669c0bcd206ba984b18af3a69e4777eb36e0bbfa9e60510b4240e5ca6d
|
| 3 |
+
size 49171295
|
gemma3_axmodel/gemma3_text_p128_l11_together.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:76bb18ef9f01fd7eb0b83e7ff8a3fbb380608845a591d3883fafc7ab735326e5
|
| 3 |
+
size 49171295
|
gemma3_axmodel/gemma3_text_p128_l12_together.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:78b16ae6f1da5e7920cbaeb8b81e7f820e651a1f121aaf24a68b9849112bf775
|
| 3 |
+
size 49171295
|
gemma3_axmodel/gemma3_text_p128_l13_together.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:de46433301abf71a9f589d3aba31f66c57cb0835db1587aedf94502c1c43b1e3
|
| 3 |
+
size 49171295
|
gemma3_axmodel/gemma3_text_p128_l14_together.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a7c95dd1fc1ed49968a741b9a2b3d7087c2399f0ed1e3d9576691dfae77bc991
|
| 3 |
+
size 49171295
|
gemma3_axmodel/gemma3_text_p128_l15_together.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:dac9bcadbc75d970fdf6732a9ef39b4d56a8bdd2563e486194e7d834ba57b2c0
|
| 3 |
+
size 49171295
|
gemma3_axmodel/gemma3_text_p128_l16_together.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:188b7b5297ace549ecd9ba625339399ae8183f1ffb074d5f9d20fe04d76c610b
|
| 3 |
+
size 49171295
|
gemma3_axmodel/gemma3_text_p128_l17_together.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ffe76e6f928f15cfb6cb1a94adc10e2616ba554e5d10a98bee555bb799790aa9
|
| 3 |
+
size 49171295
|
gemma3_axmodel/gemma3_text_p128_l18_together.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:efa5b71ba7207aa04ec3bb4f22817a1e72eb9dabf5381ca5790380a44fa26384
|
| 3 |
+
size 49171295
|
gemma3_axmodel/gemma3_text_p128_l19_together.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:02c8821748ab677842c87488c6d23f45904402ea3b06850c28405163c682ec76
|
| 3 |
+
size 49171295
|
gemma3_axmodel/gemma3_text_p128_l1_together.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c6d3b9494fd65894d66ad054e3cbaec3a861a1567234c235b4b056c5979069a6
|
| 3 |
+
size 49171287
|
gemma3_axmodel/gemma3_text_p128_l20_together.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d6813f8329752059f568b66cff66fb2348e5e6cb246a2819330a8d2825b566ae
|
| 3 |
+
size 49171295
|
gemma3_axmodel/gemma3_text_p128_l21_together.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2bae38bb4f567b2e36f4e87d00a0dad232053b7c0590e4b95708382a355f321c
|
| 3 |
+
size 49171295
|
gemma3_axmodel/gemma3_text_p128_l22_together.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a5b5303007da3f87badca09aac893debdd6799c94884e4980b487db0dea9364d
|
| 3 |
+
size 49171295
|
gemma3_axmodel/gemma3_text_p128_l23_together.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6e250358450e6e65afc6a3f817313c003831a982e07b9d923926253b0d54115f
|
| 3 |
+
size 49171295
|
gemma3_axmodel/gemma3_text_p128_l24_together.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:36a0e3ee62e4a48ca8f3e135219ff5ecddddf9f4f2f161c1fdb6c14489a0f5ce
|
| 3 |
+
size 49171295
|
gemma3_axmodel/gemma3_text_p128_l25_together.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6437d6f987ce165fdd71b3469b55c344ab262f0c7f4e03e2aeaee357ab302b5d
|
| 3 |
+
size 49171295
|
gemma3_axmodel/gemma3_text_p128_l2_together.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9ea4e3a444073f25769587176d6522bda9b3fbfd52afcb864c67c0b58f35e8d5
|
| 3 |
+
size 49171287
|
gemma3_axmodel/gemma3_text_p128_l3_together.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3ddc3af4ecef0828c2b7996bb44ff29f4e17b2d5e82bd3624426d48d21e5764b
|
| 3 |
+
size 49171287
|
gemma3_axmodel/gemma3_text_p128_l4_together.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:68b22ec763df9b6b51121e61024316ff4ed19b12bf2ea876352273877d001a0b
|
| 3 |
+
size 49171287
|
gemma3_axmodel/gemma3_text_p128_l5_together.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3e963ea9006e1323271d271475b9f9e412cef7d9282100472dac2df6b313adfb
|
| 3 |
+
size 49171287
|
gemma3_axmodel/gemma3_text_p128_l6_together.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4c4d669898cf4abf43063257b07f005d5634aa64fa094c3e748b23b7f205c7f8
|
| 3 |
+
size 49171287
|
gemma3_axmodel/gemma3_text_p128_l7_together.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a9e4164dc6399f4e9fd53023c60e4b0cf4ab71d36440353d30cf36add0cac684
|
| 3 |
+
size 49171287
|
gemma3_axmodel/gemma3_text_p128_l8_together.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3531176604bf11d6a40399cb4aa4bcdc1937535c6b9df02d6874a1a7500aa210
|
| 3 |
+
size 49171287
|
gemma3_axmodel/gemma3_text_p128_l9_together.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:156c73054276cc93e16d4e22868631b28939b5adae65bb96c618bedf2afed844
|
| 3 |
+
size 49171287
|
gemma3_axmodel/gemma3_text_post.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8ee8d28a0110d6d079a5299758b07ff7361ebd29dbbd6f29343705d2be1eb635
|
| 3 |
+
size 335820123
|
gemma3_axmodel/model.embed_tokens.weight.bfloat16.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8169474a8e8d2df7cc2ab19a5dd6aed0c2fe51afa2d9bfff2fc5235eea2bae26
|
| 3 |
+
size 603979776
|
gemma3_axmodel/model.embed_tokens.weight.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:537d4d5f2547a5382fa6064bf9da9017e09c60c38eca0317bb80a8479124f8c1
|
| 3 |
+
size 1207959680
|
gemma3_tokenizer/.gitattributes
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
model.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
tokenizer.model filter=lfs diff=lfs merge=lfs -text
|
gemma3_tokenizer/README.md
ADDED
|
@@ -0,0 +1,514 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: gemma
|
| 3 |
+
library_name: transformers
|
| 4 |
+
pipeline_tag: text-generation
|
| 5 |
+
extra_gated_heading: Access Gemma on Hugging Face
|
| 6 |
+
extra_gated_prompt: To access Gemma on Hugging Face, you’re required to review and
|
| 7 |
+
agree to Google’s usage license. To do this, please ensure you’re logged in to Hugging
|
| 8 |
+
Face and click below. Requests are processed immediately.
|
| 9 |
+
extra_gated_button_content: Acknowledge license
|
| 10 |
+
base_model: google/gemma-3-1b-pt
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
# Gemma 3 model card
|
| 14 |
+
|
| 15 |
+
**Model Page**: [Gemma](https://ai.google.dev/gemma/docs/core)
|
| 16 |
+
|
| 17 |
+
**Resources and Technical Documentation**:
|
| 18 |
+
|
| 19 |
+
* [Gemma 3 Technical Report][g3-tech-report]
|
| 20 |
+
* [Responsible Generative AI Toolkit][rai-toolkit]
|
| 21 |
+
* [Gemma on Kaggle][kaggle-gemma]
|
| 22 |
+
* [Gemma on Vertex Model Garden][vertex-mg-gemma3]
|
| 23 |
+
|
| 24 |
+
**Terms of Use**: [Terms][terms]
|
| 25 |
+
|
| 26 |
+
**Authors**: Google DeepMind
|
| 27 |
+
|
| 28 |
+
## Model Information
|
| 29 |
+
|
| 30 |
+
Summary description and brief definition of inputs and outputs.
|
| 31 |
+
|
| 32 |
+
### Description
|
| 33 |
+
|
| 34 |
+
Gemma is a family of lightweight, state-of-the-art open models from Google,
|
| 35 |
+
built from the same research and technology used to create the Gemini models.
|
| 36 |
+
Gemma 3 models are multimodal, handling text and image input and generating text
|
| 37 |
+
output, with open weights for both pre-trained variants and instruction-tuned
|
| 38 |
+
variants. Gemma 3 has a large, 128K context window, multilingual support in over
|
| 39 |
+
140 languages, and is available in more sizes than previous versions. Gemma 3
|
| 40 |
+
models are well-suited for a variety of text generation and image understanding
|
| 41 |
+
tasks, including question answering, summarization, and reasoning. Their
|
| 42 |
+
relatively small size makes it possible to deploy them in environments with
|
| 43 |
+
limited resources such as laptops, desktops or your own cloud infrastructure,
|
| 44 |
+
democratizing access to state of the art AI models and helping foster innovation
|
| 45 |
+
for everyone.
|
| 46 |
+
|
| 47 |
+
### Inputs and outputs
|
| 48 |
+
|
| 49 |
+
- **Input:**
|
| 50 |
+
- Text string, such as a question, a prompt, or a document to be summarized
|
| 51 |
+
- Images, normalized to 896 x 896 resolution and encoded to 256 tokens
|
| 52 |
+
each
|
| 53 |
+
- Total input context of 128K tokens for the 4B, 12B, and 27B sizes, and
|
| 54 |
+
32K tokens for the 1B size
|
| 55 |
+
|
| 56 |
+
- **Output:**
|
| 57 |
+
- Generated text in response to the input, such as an answer to a
|
| 58 |
+
question, analysis of image content, or a summary of a document
|
| 59 |
+
- Total output context of 8192 tokens
|
| 60 |
+
|
| 61 |
+
### Usage
|
| 62 |
+
|
| 63 |
+
Below, there are some code snippets on how to get quickly started with running the model. First, install the Transformers library. Gemma 3 is supported starting from transformers 4.50.0.
|
| 64 |
+
|
| 65 |
+
```sh
|
| 66 |
+
$ pip install -U transformers
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
Then, copy the snippet from the section that is relevant for your use case.
|
| 70 |
+
|
| 71 |
+
#### Running with the `pipeline` API
|
| 72 |
+
|
| 73 |
+
With instruction-tuned models, you need to use chat templates to process our inputs first. Then, you can pass it to the pipeline.
|
| 74 |
+
|
| 75 |
+
```python
|
| 76 |
+
from transformers import pipeline
|
| 77 |
+
import torch
|
| 78 |
+
|
| 79 |
+
pipe = pipeline("text-generation", model="google/gemma-3-1b-it", device="cuda", torch_dtype=torch.bfloat16)
|
| 80 |
+
|
| 81 |
+
messages = [
|
| 82 |
+
[
|
| 83 |
+
{
|
| 84 |
+
"role": "system",
|
| 85 |
+
"content": [{"type": "text", "text": "You are a helpful assistant."},]
|
| 86 |
+
},
|
| 87 |
+
{
|
| 88 |
+
"role": "user",
|
| 89 |
+
"content": [{"type": "text", "text": "Write a poem on Hugging Face, the company"},]
|
| 90 |
+
},
|
| 91 |
+
],
|
| 92 |
+
]
|
| 93 |
+
|
| 94 |
+
output = pipe(messages, max_new_tokens=50)
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
#### Running the model on a single / multi GPU
|
| 98 |
+
|
| 99 |
+
```python
|
| 100 |
+
from transformers import AutoTokenizer, BitsAndBytesConfig, Gemma3ForCausalLM
|
| 101 |
+
import torch
|
| 102 |
+
|
| 103 |
+
model_id = "google/gemma-3-1b-it"
|
| 104 |
+
|
| 105 |
+
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
|
| 106 |
+
|
| 107 |
+
model = Gemma3ForCausalLM.from_pretrained(
|
| 108 |
+
model_id, quantization_config=quantization_config
|
| 109 |
+
).eval()
|
| 110 |
+
|
| 111 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
| 112 |
+
|
| 113 |
+
messages = [
|
| 114 |
+
[
|
| 115 |
+
{
|
| 116 |
+
"role": "system",
|
| 117 |
+
"content": [{"type": "text", "text": "You are a helpful assistant."},]
|
| 118 |
+
},
|
| 119 |
+
{
|
| 120 |
+
"role": "user",
|
| 121 |
+
"content": [{"type": "text", "text": "Write a poem on Hugging Face, the company"},]
|
| 122 |
+
},
|
| 123 |
+
],
|
| 124 |
+
]
|
| 125 |
+
inputs = tokenizer.apply_chat_template(
|
| 126 |
+
messages,
|
| 127 |
+
add_generation_prompt=True,
|
| 128 |
+
tokenize=True,
|
| 129 |
+
return_dict=True,
|
| 130 |
+
return_tensors="pt",
|
| 131 |
+
).to(model.device).to(torch.bfloat16)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
with torch.inference_mode():
|
| 135 |
+
outputs = model.generate(**inputs, max_new_tokens=64)
|
| 136 |
+
|
| 137 |
+
outputs = tokenizer.batch_decode(outputs)
|
| 138 |
+
```
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
### Citation
|
| 142 |
+
|
| 143 |
+
```none
|
| 144 |
+
@article{gemma_2025,
|
| 145 |
+
title={Gemma 3},
|
| 146 |
+
url={https://goo.gle/Gemma3Report},
|
| 147 |
+
publisher={Kaggle},
|
| 148 |
+
author={Gemma Team},
|
| 149 |
+
year={2025}
|
| 150 |
+
}
|
| 151 |
+
```
|
| 152 |
+
|
| 153 |
+
## Model Data
|
| 154 |
+
|
| 155 |
+
Data used for model training and how the data was processed.
|
| 156 |
+
|
| 157 |
+
### Training Dataset
|
| 158 |
+
|
| 159 |
+
These models were trained on a dataset of text data that includes a wide variety
|
| 160 |
+
of sources. The 27B model was trained with 14 trillion tokens, the 12B model was
|
| 161 |
+
trained with 12 trillion tokens, 4B model was trained with 4 trillion tokens and
|
| 162 |
+
1B with 2 trillion tokens. Here are the key components:
|
| 163 |
+
|
| 164 |
+
- Web Documents: A diverse collection of web text ensures the model is
|
| 165 |
+
exposed to a broad range of linguistic styles, topics, and vocabulary. The
|
| 166 |
+
training dataset includes content in over 140 languages.
|
| 167 |
+
- Code: Exposing the model to code helps it to learn the syntax and
|
| 168 |
+
patterns of programming languages, which improves its ability to generate
|
| 169 |
+
code and understand code-related questions.
|
| 170 |
+
- Mathematics: Training on mathematical text helps the model learn logical
|
| 171 |
+
reasoning, symbolic representation, and to address mathematical queries.
|
| 172 |
+
- Images: A wide range of images enables the model to perform image
|
| 173 |
+
analysis and visual data extraction tasks.
|
| 174 |
+
|
| 175 |
+
The combination of these diverse data sources is crucial for training a powerful
|
| 176 |
+
multimodal model that can handle a wide variety of different tasks and data
|
| 177 |
+
formats.
|
| 178 |
+
|
| 179 |
+
### Data Preprocessing
|
| 180 |
+
|
| 181 |
+
Here are the key data cleaning and filtering methods applied to the training
|
| 182 |
+
data:
|
| 183 |
+
|
| 184 |
+
- CSAM Filtering: Rigorous CSAM (Child Sexual Abuse Material) filtering
|
| 185 |
+
was applied at multiple stages in the data preparation process to ensure
|
| 186 |
+
the exclusion of harmful and illegal content.
|
| 187 |
+
- Sensitive Data Filtering: As part of making Gemma pre-trained models
|
| 188 |
+
safe and reliable, automated techniques were used to filter out certain
|
| 189 |
+
personal information and other sensitive data from training sets.
|
| 190 |
+
- Additional methods: Filtering based on content quality and safety in
|
| 191 |
+
line with [our policies][safety-policies].
|
| 192 |
+
|
| 193 |
+
## Implementation Information
|
| 194 |
+
|
| 195 |
+
Details about the model internals.
|
| 196 |
+
|
| 197 |
+
### Hardware
|
| 198 |
+
|
| 199 |
+
Gemma was trained using [Tensor Processing Unit (TPU)][tpu] hardware (TPUv4p,
|
| 200 |
+
TPUv5p and TPUv5e). Training vision-language models (VLMS) requires significant
|
| 201 |
+
computational power. TPUs, designed specifically for matrix operations common in
|
| 202 |
+
machine learning, offer several advantages in this domain:
|
| 203 |
+
|
| 204 |
+
- Performance: TPUs are specifically designed to handle the massive
|
| 205 |
+
computations involved in training VLMs. They can speed up training
|
| 206 |
+
considerably compared to CPUs.
|
| 207 |
+
- Memory: TPUs often come with large amounts of high-bandwidth memory,
|
| 208 |
+
allowing for the handling of large models and batch sizes during training.
|
| 209 |
+
This can lead to better model quality.
|
| 210 |
+
- Scalability: TPU Pods (large clusters of TPUs) provide a scalable
|
| 211 |
+
solution for handling the growing complexity of large foundation models.
|
| 212 |
+
You can distribute training across multiple TPU devices for faster and more
|
| 213 |
+
efficient processing.
|
| 214 |
+
- Cost-effectiveness: In many scenarios, TPUs can provide a more
|
| 215 |
+
cost-effective solution for training large models compared to CPU-based
|
| 216 |
+
infrastructure, especially when considering the time and resources saved
|
| 217 |
+
due to faster training.
|
| 218 |
+
- These advantages are aligned with
|
| 219 |
+
[Google's commitments to operate sustainably][sustainability].
|
| 220 |
+
|
| 221 |
+
### Software
|
| 222 |
+
|
| 223 |
+
Training was done using [JAX][jax] and [ML Pathways][ml-pathways].
|
| 224 |
+
|
| 225 |
+
JAX allows researchers to take advantage of the latest generation of hardware,
|
| 226 |
+
including TPUs, for faster and more efficient training of large models. ML
|
| 227 |
+
Pathways is Google's latest effort to build artificially intelligent systems
|
| 228 |
+
capable of generalizing across multiple tasks. This is specially suitable for
|
| 229 |
+
foundation models, including large language models like these ones.
|
| 230 |
+
|
| 231 |
+
Together, JAX and ML Pathways are used as described in the
|
| 232 |
+
[paper about the Gemini family of models][gemini-2-paper]; *"the 'single
|
| 233 |
+
controller' programming model of Jax and Pathways allows a single Python
|
| 234 |
+
process to orchestrate the entire training run, dramatically simplifying the
|
| 235 |
+
development workflow."*
|
| 236 |
+
|
| 237 |
+
## Evaluation
|
| 238 |
+
|
| 239 |
+
Model evaluation metrics and results.
|
| 240 |
+
|
| 241 |
+
### Benchmark Results
|
| 242 |
+
|
| 243 |
+
These models were evaluated against a large collection of different datasets and
|
| 244 |
+
metrics to cover different aspects of text generation:
|
| 245 |
+
|
| 246 |
+
#### Reasoning and factuality
|
| 247 |
+
|
| 248 |
+
| Benchmark | Metric | Gemma 3 PT 1B | Gemma 3 PT 4B | Gemma 3 PT 12B | Gemma 3 PT 27B |
|
| 249 |
+
| ------------------------------ |----------------|:--------------:|:-------------:|:--------------:|:--------------:|
|
| 250 |
+
| [HellaSwag][hellaswag] | 10-shot | 62.3 | 77.2 | 84.2 | 85.6 |
|
| 251 |
+
| [BoolQ][boolq] | 0-shot | 63.2 | 72.3 | 78.8 | 82.4 |
|
| 252 |
+
| [PIQA][piqa] | 0-shot | 73.8 | 79.6 | 81.8 | 83.3 |
|
| 253 |
+
| [SocialIQA][socialiqa] | 0-shot | 48.9 | 51.9 | 53.4 | 54.9 |
|
| 254 |
+
| [TriviaQA][triviaqa] | 5-shot | 39.8 | 65.8 | 78.2 | 85.5 |
|
| 255 |
+
| [Natural Questions][naturalq] | 5-shot | 9.48 | 20.0 | 31.4 | 36.1 |
|
| 256 |
+
| [ARC-c][arc] | 25-shot | 38.4 | 56.2 | 68.9 | 70.6 |
|
| 257 |
+
| [ARC-e][arc] | 0-shot | 73.0 | 82.4 | 88.3 | 89.0 |
|
| 258 |
+
| [WinoGrande][winogrande] | 5-shot | 58.2 | 64.7 | 74.3 | 78.8 |
|
| 259 |
+
| [BIG-Bench Hard][bbh] | few-shot | 28.4 | 50.9 | 72.6 | 77.7 |
|
| 260 |
+
| [DROP][drop] | 1-shot | 42.4 | 60.1 | 72.2 | 77.2 |
|
| 261 |
+
|
| 262 |
+
[hellaswag]: https://arxiv.org/abs/1905.07830
|
| 263 |
+
[boolq]: https://arxiv.org/abs/1905.10044
|
| 264 |
+
[piqa]: https://arxiv.org/abs/1911.11641
|
| 265 |
+
[socialiqa]: https://arxiv.org/abs/1904.09728
|
| 266 |
+
[triviaqa]: https://arxiv.org/abs/1705.03551
|
| 267 |
+
[naturalq]: https://github.com/google-research-datasets/natural-questions
|
| 268 |
+
[arc]: https://arxiv.org/abs/1911.01547
|
| 269 |
+
[winogrande]: https://arxiv.org/abs/1907.10641
|
| 270 |
+
[bbh]: https://paperswithcode.com/dataset/bbh
|
| 271 |
+
[drop]: https://arxiv.org/abs/1903.00161
|
| 272 |
+
|
| 273 |
+
#### STEM and code
|
| 274 |
+
|
| 275 |
+
| Benchmark | Metric | Gemma 3 PT 4B | Gemma 3 PT 12B | Gemma 3 PT 27B |
|
| 276 |
+
| ------------------------------ |----------------|:-------------:|:--------------:|:--------------:|
|
| 277 |
+
| [MMLU][mmlu] | 5-shot | 59.6 | 74.5 | 78.6 |
|
| 278 |
+
| [MMLU][mmlu] (Pro COT) | 5-shot | 29.2 | 45.3 | 52.2 |
|
| 279 |
+
| [AGIEval][agieval] | 3-5-shot | 42.1 | 57.4 | 66.2 |
|
| 280 |
+
| [MATH][math] | 4-shot | 24.2 | 43.3 | 50.0 |
|
| 281 |
+
| [GSM8K][gsm8k] | 8-shot | 38.4 | 71.0 | 82.6 |
|
| 282 |
+
| [GPQA][gpqa] | 5-shot | 15.0 | 25.4 | 24.3 |
|
| 283 |
+
| [MBPP][mbpp] | 3-shot | 46.0 | 60.4 | 65.6 |
|
| 284 |
+
| [HumanEval][humaneval] | 0-shot | 36.0 | 45.7 | 48.8 |
|
| 285 |
+
|
| 286 |
+
[mmlu]: https://arxiv.org/abs/2009.03300
|
| 287 |
+
[agieval]: https://arxiv.org/abs/2304.06364
|
| 288 |
+
[math]: https://arxiv.org/abs/2103.03874
|
| 289 |
+
[gsm8k]: https://arxiv.org/abs/2110.14168
|
| 290 |
+
[gpqa]: https://arxiv.org/abs/2311.12022
|
| 291 |
+
[mbpp]: https://arxiv.org/abs/2108.07732
|
| 292 |
+
[humaneval]: https://arxiv.org/abs/2107.03374
|
| 293 |
+
|
| 294 |
+
#### Multilingual
|
| 295 |
+
|
| 296 |
+
| Benchmark | Gemma 3 PT 1B | Gemma 3 PT 4B | Gemma 3 PT 12B | Gemma 3 PT 27B |
|
| 297 |
+
| ------------------------------------ |:-------------:|:-------------:|:--------------:|:--------------:|
|
| 298 |
+
| [MGSM][mgsm] | 2.04 | 34.7 | 64.3 | 74.3 |
|
| 299 |
+
| [Global-MMLU-Lite][global-mmlu-lite] | 24.9 | 57.0 | 69.4 | 75.7 |
|
| 300 |
+
| [WMT24++][wmt24pp] (ChrF) | 36.7 | 48.4 | 53.9 | 55.7 |
|
| 301 |
+
| [FloRes][flores] | 29.5 | 39.2 | 46.0 | 48.8 |
|
| 302 |
+
| [XQuAD][xquad] (all) | 43.9 | 68.0 | 74.5 | 76.8 |
|
| 303 |
+
| [ECLeKTic][eclektic] | 4.69 | 11.0 | 17.2 | 24.4 |
|
| 304 |
+
| [IndicGenBench][indicgenbench] | 41.4 | 57.2 | 61.7 | 63.4 |
|
| 305 |
+
|
| 306 |
+
[mgsm]: https://arxiv.org/abs/2210.03057
|
| 307 |
+
[flores]: https://arxiv.org/abs/2106.03193
|
| 308 |
+
[xquad]: https://arxiv.org/abs/1910.11856v3
|
| 309 |
+
[global-mmlu-lite]: https://huggingface.co/datasets/CohereForAI/Global-MMLU-Lite
|
| 310 |
+
[wmt24pp]: https://arxiv.org/abs/2502.12404v1
|
| 311 |
+
[eclektic]: https://arxiv.org/abs/2502.21228
|
| 312 |
+
[indicgenbench]: https://arxiv.org/abs/2404.16816
|
| 313 |
+
|
| 314 |
+
#### Multimodal
|
| 315 |
+
|
| 316 |
+
| Benchmark | Gemma 3 PT 4B | Gemma 3 PT 12B | Gemma 3 PT 27B |
|
| 317 |
+
| ------------------------------ |:-------------:|:--------------:|:--------------:|
|
| 318 |
+
| [COCOcap][coco-cap] | 102 | 111 | 116 |
|
| 319 |
+
| [DocVQA][docvqa] (val) | 72.8 | 82.3 | 85.6 |
|
| 320 |
+
| [InfoVQA][info-vqa] (val) | 44.1 | 54.8 | 59.4 |
|
| 321 |
+
| [MMMU][mmmu] (pt) | 39.2 | 50.3 | 56.1 |
|
| 322 |
+
| [TextVQA][textvqa] (val) | 58.9 | 66.5 | 68.6 |
|
| 323 |
+
| [RealWorldQA][realworldqa] | 45.5 | 52.2 | 53.9 |
|
| 324 |
+
| [ReMI][remi] | 27.3 | 38.5 | 44.8 |
|
| 325 |
+
| [AI2D][ai2d] | 63.2 | 75.2 | 79.0 |
|
| 326 |
+
| [ChartQA][chartqa] | 63.6 | 74.7 | 76.3 |
|
| 327 |
+
| [VQAv2][vqav2] | 63.9 | 71.2 | 72.9 |
|
| 328 |
+
| [BLINK][blinkvqa] | 38.0 | 35.9 | 39.6 |
|
| 329 |
+
| [OKVQA][okvqa] | 51.0 | 58.7 | 60.2 |
|
| 330 |
+
| [TallyQA][tallyqa] | 42.5 | 51.8 | 54.3 |
|
| 331 |
+
| [SpatialSense VQA][ss-vqa] | 50.9 | 60.0 | 59.4 |
|
| 332 |
+
| [CountBenchQA][countbenchqa] | 26.1 | 17.8 | 68.0 |
|
| 333 |
+
|
| 334 |
+
[coco-cap]: https://cocodataset.org/#home
|
| 335 |
+
[docvqa]: https://www.docvqa.org/
|
| 336 |
+
[info-vqa]: https://arxiv.org/abs/2104.12756
|
| 337 |
+
[mmmu]: https://arxiv.org/abs/2311.16502
|
| 338 |
+
[textvqa]: https://textvqa.org/
|
| 339 |
+
[realworldqa]: https://paperswithcode.com/dataset/realworldqa
|
| 340 |
+
[remi]: https://arxiv.org/html/2406.09175v1
|
| 341 |
+
[ai2d]: https://allenai.org/data/diagrams
|
| 342 |
+
[chartqa]: https://arxiv.org/abs/2203.10244
|
| 343 |
+
[vqav2]: https://visualqa.org/index.html
|
| 344 |
+
[blinkvqa]: https://arxiv.org/abs/2404.12390
|
| 345 |
+
[okvqa]: https://okvqa.allenai.org/
|
| 346 |
+
[tallyqa]: https://arxiv.org/abs/1810.12440
|
| 347 |
+
[ss-vqa]: https://arxiv.org/abs/1908.02660
|
| 348 |
+
[countbenchqa]: https://github.com/google-research/big_vision/blob/main/big_vision/datasets/countbenchqa/
|
| 349 |
+
|
| 350 |
+
## Ethics and Safety
|
| 351 |
+
|
| 352 |
+
Ethics and safety evaluation approach and results.
|
| 353 |
+
|
| 354 |
+
### Evaluation Approach
|
| 355 |
+
|
| 356 |
+
Our evaluation methods include structured evaluations and internal red-teaming
|
| 357 |
+
testing of relevant content policies. Red-teaming was conducted by a number of
|
| 358 |
+
different teams, each with different goals and human evaluation metrics. These
|
| 359 |
+
models were evaluated against a number of different categories relevant to
|
| 360 |
+
ethics and safety, including:
|
| 361 |
+
|
| 362 |
+
- **Child Safety**: Evaluation of text-to-text and image to text prompts
|
| 363 |
+
covering child safety policies, including child sexual abuse and
|
| 364 |
+
exploitation.
|
| 365 |
+
- **Content Safety:** Evaluation of text-to-text and image to text prompts
|
| 366 |
+
covering safety policies including, harassment, violence and gore, and hate
|
| 367 |
+
speech.
|
| 368 |
+
- **Representational Harms**: Evaluation of text-to-text and image to text
|
| 369 |
+
prompts covering safety policies including bias, stereotyping, and harmful
|
| 370 |
+
associations or inaccuracies.
|
| 371 |
+
|
| 372 |
+
In addition to development level evaluations, we conduct "assurance
|
| 373 |
+
evaluations" which are our 'arms-length' internal evaluations for responsibility
|
| 374 |
+
governance decision making. They are conducted separately from the model
|
| 375 |
+
development team, to inform decision making about release. High level findings
|
| 376 |
+
are fed back to the model team, but prompt sets are held-out to prevent
|
| 377 |
+
overfitting and preserve the results' ability to inform decision making.
|
| 378 |
+
Assurance evaluation results are reported to our Responsibility & Safety Council
|
| 379 |
+
as part of release review.
|
| 380 |
+
|
| 381 |
+
### Evaluation Results
|
| 382 |
+
|
| 383 |
+
For all areas of safety testing, we saw major improvements in the categories of
|
| 384 |
+
child safety, content safety, and representational harms relative to previous
|
| 385 |
+
Gemma models. All testing was conducted without safety filters to evaluate the
|
| 386 |
+
model capabilities and behaviors. For both text-to-text and image-to-text, and
|
| 387 |
+
across all model sizes, the model produced minimal policy violations, and showed
|
| 388 |
+
significant improvements over previous Gemma models' performance with respect
|
| 389 |
+
to ungrounded inferences. A limitation of our evaluations was they included only
|
| 390 |
+
English language prompts.
|
| 391 |
+
|
| 392 |
+
## Usage and Limitations
|
| 393 |
+
|
| 394 |
+
These models have certain limitations that users should be aware of.
|
| 395 |
+
|
| 396 |
+
### Intended Usage
|
| 397 |
+
|
| 398 |
+
Open vision-language models (VLMs) models have a wide range of applications
|
| 399 |
+
across various industries and domains. The following list of potential uses is
|
| 400 |
+
not comprehensive. The purpose of this list is to provide contextual information
|
| 401 |
+
about the possible use-cases that the model creators considered as part of model
|
| 402 |
+
training and development.
|
| 403 |
+
|
| 404 |
+
- Content Creation and Communication
|
| 405 |
+
- Text Generation: These models can be used to generate creative text
|
| 406 |
+
formats such as poems, scripts, code, marketing copy, and email drafts.
|
| 407 |
+
- Chatbots and Conversational AI: Power conversational interfaces
|
| 408 |
+
for customer service, virtual assistants, or interactive applications.
|
| 409 |
+
- Text Summarization: Generate concise summaries of a text corpus,
|
| 410 |
+
research papers, or reports.
|
| 411 |
+
- Image Data Extraction: These models can be used to extract,
|
| 412 |
+
interpret, and summarize visual data for text communications.
|
| 413 |
+
- Research and Education
|
| 414 |
+
- Natural Language Processing (NLP) and VLM Research: These
|
| 415 |
+
models can serve as a foundation for researchers to experiment with VLM
|
| 416 |
+
and NLP techniques, develop algorithms, and contribute to the
|
| 417 |
+
advancement of the field.
|
| 418 |
+
- Language Learning Tools: Support interactive language learning
|
| 419 |
+
experiences, aiding in grammar correction or providing writing practice.
|
| 420 |
+
- Knowledge Exploration: Assist researchers in exploring large
|
| 421 |
+
bodies of text by generating summaries or answering questions about
|
| 422 |
+
specific topics.
|
| 423 |
+
|
| 424 |
+
### Limitations
|
| 425 |
+
|
| 426 |
+
- Training Data
|
| 427 |
+
- The quality and diversity of the training data significantly
|
| 428 |
+
influence the model's capabilities. Biases or gaps in the training data
|
| 429 |
+
can lead to limitations in the model's responses.
|
| 430 |
+
- The scope of the training dataset determines the subject areas
|
| 431 |
+
the model can handle effectively.
|
| 432 |
+
- Context and Task Complexity
|
| 433 |
+
- Models are better at tasks that can be framed with clear
|
| 434 |
+
prompts and instructions. Open-ended or highly complex tasks might be
|
| 435 |
+
challenging.
|
| 436 |
+
- A model's performance can be influenced by the amount of context
|
| 437 |
+
provided (longer context generally leads to better outputs, up to a
|
| 438 |
+
certain point).
|
| 439 |
+
- Language Ambiguity and Nuance
|
| 440 |
+
- Natural language is inherently complex. Models might struggle
|
| 441 |
+
to grasp subtle nuances, sarcasm, or figurative language.
|
| 442 |
+
- Factual Accuracy
|
| 443 |
+
- Models generate responses based on information they learned
|
| 444 |
+
from their training datasets, but they are not knowledge bases. They
|
| 445 |
+
may generate incorrect or outdated factual statements.
|
| 446 |
+
- Common Sense
|
| 447 |
+
- Models rely on statistical patterns in language. They might
|
| 448 |
+
lack the ability to apply common sense reasoning in certain situations.
|
| 449 |
+
|
| 450 |
+
### Ethical Considerations and Risks
|
| 451 |
+
|
| 452 |
+
The development of vision-language models (VLMs) raises several ethical
|
| 453 |
+
concerns. In creating an open model, we have carefully considered the following:
|
| 454 |
+
|
| 455 |
+
- Bias and Fairness
|
| 456 |
+
- VLMs trained on large-scale, real-world text and image data can
|
| 457 |
+
reflect socio-cultural biases embedded in the training material. These
|
| 458 |
+
models underwent careful scrutiny, input data pre-processing described
|
| 459 |
+
and posterior evaluations reported in this card.
|
| 460 |
+
- Misinformation and Misuse
|
| 461 |
+
- VLMs can be misused to generate text that is false, misleading,
|
| 462 |
+
or harmful.
|
| 463 |
+
- Guidelines are provided for responsible use with the model, see the
|
| 464 |
+
[Responsible Generative AI Toolkit][rai-toolkit].
|
| 465 |
+
- Transparency and Accountability:
|
| 466 |
+
- This model card summarizes details on the models' architecture,
|
| 467 |
+
capabilities, limitations, and evaluation processes.
|
| 468 |
+
- A responsibly developed open model offers the opportunity to
|
| 469 |
+
share innovation by making VLM technology accessible to developers and
|
| 470 |
+
researchers across the AI ecosystem.
|
| 471 |
+
|
| 472 |
+
Risks identified and mitigations:
|
| 473 |
+
|
| 474 |
+
- **Perpetuation of biases**: It's encouraged to perform continuous
|
| 475 |
+
monitoring (using evaluation metrics, human review) and the exploration of
|
| 476 |
+
de-biasing techniques during model training, fine-tuning, and other use
|
| 477 |
+
cases.
|
| 478 |
+
- **Generation of harmful content**: Mechanisms and guidelines for content
|
| 479 |
+
safety are essential. Developers are encouraged to exercise caution and
|
| 480 |
+
implement appropriate content safety safeguards based on their specific
|
| 481 |
+
product policies and application use cases.
|
| 482 |
+
- **Misuse for malicious purposes**: Technical limitations and developer
|
| 483 |
+
and end-user education can help mitigate against malicious applications of
|
| 484 |
+
VLMs. Educational resources and reporting mechanisms for users to flag
|
| 485 |
+
misuse are provided. Prohibited uses of Gemma models are outlined in the
|
| 486 |
+
[Gemma Prohibited Use Policy][prohibited-use].
|
| 487 |
+
- **Privacy violations**: Models were trained on data filtered for removal
|
| 488 |
+
of certain personal information and other sensitive data. Developers are
|
| 489 |
+
encouraged to adhere to privacy regulations with privacy-preserving
|
| 490 |
+
techniques.
|
| 491 |
+
|
| 492 |
+
### Benefits
|
| 493 |
+
|
| 494 |
+
At the time of release, this family of models provides high-performance open
|
| 495 |
+
vision-language model implementations designed from the ground up for
|
| 496 |
+
responsible AI development compared to similarly sized models.
|
| 497 |
+
|
| 498 |
+
Using the benchmark evaluation metrics described in this document, these models
|
| 499 |
+
have shown to provide superior performance to other, comparably-sized open model
|
| 500 |
+
alternatives.
|
| 501 |
+
|
| 502 |
+
[g3-tech-report]: https://goo.gle/Gemma3Report
|
| 503 |
+
[rai-toolkit]: https://ai.google.dev/responsible
|
| 504 |
+
[kaggle-gemma]: https://www.kaggle.com/models/google/gemma-3
|
| 505 |
+
[vertex-mg-gemma3]: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/gemma3
|
| 506 |
+
[terms]: https://ai.google.dev/gemma/terms
|
| 507 |
+
[safety-policies]: https://ai.google/static/documents/ai-responsibility-update-published-february-2025.pdf
|
| 508 |
+
[prohibited-use]: https://ai.google.dev/gemma/prohibited_use_policy
|
| 509 |
+
[tpu]: https://cloud.google.com/tpu/docs/intro-to-tpu
|
| 510 |
+
[sustainability]: https://sustainability.google/operating-sustainably/
|
| 511 |
+
[jax]: https://github.com/jax-ml/jax
|
| 512 |
+
[ml-pathways]: https://blog.google/technology/ai/introducing-pathways-next-generation-ai-architecture/
|
| 513 |
+
[sustainability]: https://sustainability.google/operating-sustainably/
|
| 514 |
+
[gemini-2-paper]: https://arxiv.org/abs/2312.11805
|
gemma3_tokenizer/added_tokens.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"<image_soft_token>": 262144
|
| 3 |
+
}
|
gemma3_tokenizer/config.json
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"Gemma3ForCausalLM"
|
| 4 |
+
],
|
| 5 |
+
"attention_bias": false,
|
| 6 |
+
"attention_dropout": 0.0,
|
| 7 |
+
"attn_logit_softcapping": null,
|
| 8 |
+
"bos_token_id": 2,
|
| 9 |
+
"cache_implementation": "hybrid",
|
| 10 |
+
"eos_token_id": [
|
| 11 |
+
1,
|
| 12 |
+
106
|
| 13 |
+
],
|
| 14 |
+
"final_logit_softcapping": null,
|
| 15 |
+
"head_dim": 256,
|
| 16 |
+
"hidden_activation": "gelu_pytorch_tanh",
|
| 17 |
+
"hidden_size": 1152,
|
| 18 |
+
"initializer_range": 0.02,
|
| 19 |
+
"intermediate_size": 6912,
|
| 20 |
+
"max_position_embeddings": 32768,
|
| 21 |
+
"model_type": "gemma3_text",
|
| 22 |
+
"num_attention_heads": 4,
|
| 23 |
+
"num_hidden_layers": 26,
|
| 24 |
+
"num_key_value_heads": 1,
|
| 25 |
+
"pad_token_id": 0,
|
| 26 |
+
"query_pre_attn_scalar": 256,
|
| 27 |
+
"rms_norm_eps": 1e-06,
|
| 28 |
+
"rope_local_base_freq": 10000,
|
| 29 |
+
"rope_scaling": null,
|
| 30 |
+
"rope_theta": 1000000,
|
| 31 |
+
"sliding_window": 512,
|
| 32 |
+
"sliding_window_pattern": 6,
|
| 33 |
+
"torch_dtype": "bfloat16",
|
| 34 |
+
"transformers_version": "4.50.0.dev0",
|
| 35 |
+
"use_cache": true,
|
| 36 |
+
"vocab_size": 262144
|
| 37 |
+
}
|
gemma3_tokenizer/generation_config.json
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token_id": 2,
|
| 3 |
+
"cache_implementation": "hybrid",
|
| 4 |
+
"do_sample": true,
|
| 5 |
+
"eos_token_id": [
|
| 6 |
+
1,
|
| 7 |
+
106
|
| 8 |
+
],
|
| 9 |
+
"pad_token_id": 0,
|
| 10 |
+
"top_k": 64,
|
| 11 |
+
"top_p": 0.95,
|
| 12 |
+
"transformers_version": "4.50.0.dev0"
|
| 13 |
+
}
|
gemma3_tokenizer/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3d4ef8d71c14db7e448a09ebe891cfb6bf32c57a9b44499ae0d1c098e48516b6
|
| 3 |
+
size 1999811208
|
gemma3_tokenizer/special_tokens_map.json
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"boi_token": "<start_of_image>",
|
| 3 |
+
"bos_token": {
|
| 4 |
+
"content": "<bos>",
|
| 5 |
+
"lstrip": false,
|
| 6 |
+
"normalized": false,
|
| 7 |
+
"rstrip": false,
|
| 8 |
+
"single_word": false
|
| 9 |
+
},
|
| 10 |
+
"eoi_token": "<end_of_image>",
|
| 11 |
+
"eos_token": {
|
| 12 |
+
"content": "<eos>",
|
| 13 |
+
"lstrip": false,
|
| 14 |
+
"normalized": false,
|
| 15 |
+
"rstrip": false,
|
| 16 |
+
"single_word": false
|
| 17 |
+
},
|
| 18 |
+
"image_token": "<image_soft_token>",
|
| 19 |
+
"pad_token": {
|
| 20 |
+
"content": "<pad>",
|
| 21 |
+
"lstrip": false,
|
| 22 |
+
"normalized": false,
|
| 23 |
+
"rstrip": false,
|
| 24 |
+
"single_word": false
|
| 25 |
+
},
|
| 26 |
+
"unk_token": {
|
| 27 |
+
"content": "<unk>",
|
| 28 |
+
"lstrip": false,
|
| 29 |
+
"normalized": false,
|
| 30 |
+
"rstrip": false,
|
| 31 |
+
"single_word": false
|
| 32 |
+
}
|
| 33 |
+
}
|
gemma3_tokenizer/tokenizer.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4667f2089529e8e7657cfb6d1c19910ae71ff5f28aa7ab2ff2763330affad795
|
| 3 |
+
size 33384568
|
gemma3_tokenizer/tokenizer.model
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c
|
| 3 |
+
size 4689074
|
gemma3_tokenizer/tokenizer_config.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
infer_axmodel.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from transformers import AutoConfig, AutoTokenizer
|
| 3 |
+
import numpy as np
|
| 4 |
+
import os
|
| 5 |
+
from typing import List, Tuple
|
| 6 |
+
from axengine import InferenceSession
|
| 7 |
+
from ml_dtypes import bfloat16
|
| 8 |
+
from utils.infer_func import InferManager
|
| 9 |
+
import argparse
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
if __name__ == "__main__":
|
| 13 |
+
|
| 14 |
+
parser = argparse.ArgumentParser(description="Model configuration parameters")
|
| 15 |
+
parser.add_argument("--hf_model", type=str, default="./gemma3_tokenizer/",
|
| 16 |
+
help="Path to HuggingFace model")
|
| 17 |
+
parser.add_argument("--axmodel_path", type=str, default="./gemma3_axmodel/",
|
| 18 |
+
help="Path to save compiled axmodel of llama model")
|
| 19 |
+
parser.add_argument("-q", "--question", type=str, default="请用中文介绍一下你自己.",
|
| 20 |
+
help="Your question that you want to ask the model.")
|
| 21 |
+
args = parser.parse_args()
|
| 22 |
+
|
| 23 |
+
hf_model_path = args.hf_model
|
| 24 |
+
axmodel_path = args.axmodel_path
|
| 25 |
+
prompt = args.question
|
| 26 |
+
|
| 27 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 28 |
+
embeds = np.load(os.path.join(axmodel_path, "model.embed_tokens.weight.npy"))
|
| 29 |
+
|
| 30 |
+
# load the tokenizer and the model
|
| 31 |
+
tokenizer = AutoTokenizer.from_pretrained(hf_model_path)
|
| 32 |
+
cfg = AutoConfig.from_pretrained(hf_model_path, trust_remote_code=True)
|
| 33 |
+
|
| 34 |
+
eos_token_id = None
|
| 35 |
+
if isinstance(cfg.eos_token_id, list) and len(cfg.eos_token_id) > 1:
|
| 36 |
+
eos_token_id = cfg.eos_token_id
|
| 37 |
+
|
| 38 |
+
######################################################################
|
| 39 |
+
# Gemma-3
|
| 40 |
+
if "gemma3" in cfg.model_type.lower():
|
| 41 |
+
messages = [
|
| 42 |
+
[
|
| 43 |
+
{
|
| 44 |
+
"role": "system",
|
| 45 |
+
"content": [{"type": "text", "text": "You are a helpful assistant."},]
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"role": "user",
|
| 49 |
+
"content": [{"type": "text", "text": prompt},]
|
| 50 |
+
},
|
| 51 |
+
],
|
| 52 |
+
]
|
| 53 |
+
|
| 54 |
+
model_inputs = tokenizer.apply_chat_template(
|
| 55 |
+
messages,
|
| 56 |
+
add_generation_prompt=True,
|
| 57 |
+
tokenize=True,
|
| 58 |
+
return_dict=True,
|
| 59 |
+
return_tensors="pt",
|
| 60 |
+
)
|
| 61 |
+
# model_inputs = tokenizer([text], return_tensors="pt").to(device)
|
| 62 |
+
input_ids = model_inputs.input_ids
|
| 63 |
+
|
| 64 |
+
# Gemma-2
|
| 65 |
+
if "gemma2" in cfg.model_type.lower():
|
| 66 |
+
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
|
| 67 |
+
######################################################################
|
| 68 |
+
|
| 69 |
+
token_ids = input_ids[0].cpu().numpy().tolist()
|
| 70 |
+
token_len = len(token_ids)
|
| 71 |
+
prefill_data = np.take(embeds, token_ids, axis=0)
|
| 72 |
+
prefill_data = prefill_data.astype(bfloat16)
|
| 73 |
+
|
| 74 |
+
imer = InferManager(cfg, axmodel_path)
|
| 75 |
+
|
| 76 |
+
token_ids = imer.prefill(tokenizer, token_ids, prefill_data, slice_len=128)
|
| 77 |
+
imer.decode(tokenizer, token_ids, embeds, slice_len=128, eos_token_id=eos_token_id)
|
| 78 |
+
print("\n")
|
utils/infer_func.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from typing import List, Tuple
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
from axengine import InferenceSession
|
| 6 |
+
import os
|
| 7 |
+
import re
|
| 8 |
+
from ml_dtypes import bfloat16
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# Discover model files automatically from model_dir.
|
| 12 |
+
# We expect files like: <prefix>_p128_l<idx>_together.axmodel and <prefix>_post.axmodel
|
| 13 |
+
# we try to detect model prefix and layer files automatically
|
| 14 |
+
def _find_axmodel_files(base_dir: str, expected_layers: int = None):
|
| 15 |
+
files = os.listdir(base_dir)
|
| 16 |
+
layer_pattern = re.compile(r"^(?P<prefix>.*)_p128_l(?P<idx>\d+)_together\.axmodel$")
|
| 17 |
+
post_pattern = re.compile(r"^(?P<prefix>.*)_post\.axmodel$")
|
| 18 |
+
|
| 19 |
+
# collect prefix -> [(idx, fname)]
|
| 20 |
+
prefix_map = {}
|
| 21 |
+
for fname in files:
|
| 22 |
+
m = layer_pattern.match(fname)
|
| 23 |
+
if m:
|
| 24 |
+
prefix = m.group("prefix")
|
| 25 |
+
idx = int(m.group("idx"))
|
| 26 |
+
prefix_map.setdefault(prefix, []).append((idx, fname))
|
| 27 |
+
|
| 28 |
+
if not prefix_map:
|
| 29 |
+
# fallback to hardcoded pattern if nothing detected
|
| 30 |
+
prefix = "gemma3_text"
|
| 31 |
+
layer_files = [(
|
| 32 |
+
i, f"{prefix}_p128_l{i}_together.axmodel"
|
| 33 |
+
) for i in range(expected_layers or 0)]
|
| 34 |
+
else:
|
| 35 |
+
# choose the prefix with the most layers (most likely the correct one)
|
| 36 |
+
prefix = max(prefix_map.items(), key=lambda kv: len(kv[1]))[0]
|
| 37 |
+
# debug info
|
| 38 |
+
print(f"Detected prefixes: {list(prefix_map.keys())}, chosen: {prefix}, layers: {len(prefix_map[prefix])}")
|
| 39 |
+
layer_files = sorted(prefix_map[prefix], key=lambda it: it[0])
|
| 40 |
+
|
| 41 |
+
# find post process file
|
| 42 |
+
post_file = None
|
| 43 |
+
for fname in files:
|
| 44 |
+
m = post_pattern.match(fname)
|
| 45 |
+
if m and m.group("prefix") == prefix:
|
| 46 |
+
post_file = fname
|
| 47 |
+
break
|
| 48 |
+
if post_file is None:
|
| 49 |
+
candidate = os.path.join(base_dir, f"{prefix}_post.axmodel")
|
| 50 |
+
if os.path.exists(candidate):
|
| 51 |
+
post_file = f"{prefix}_post.axmodel"
|
| 52 |
+
else:
|
| 53 |
+
for fname in files:
|
| 54 |
+
if fname.endswith("_post.axmodel"):
|
| 55 |
+
post_file = fname
|
| 56 |
+
break
|
| 57 |
+
|
| 58 |
+
return layer_files, post_file, prefix
|
| 59 |
+
|
| 60 |
+
class InferManager:
|
| 61 |
+
def __init__(self, config, model_dir):
|
| 62 |
+
|
| 63 |
+
self.config = config
|
| 64 |
+
self.max_seq_len = 2559
|
| 65 |
+
|
| 66 |
+
self.sub_dim = config.hidden_size // config.num_attention_heads if not config.head_dim else config.head_dim
|
| 67 |
+
self.kv_dim = self.sub_dim * config.num_key_value_heads
|
| 68 |
+
|
| 69 |
+
self.k_caches = [
|
| 70 |
+
np.zeros((1, self.max_seq_len, self.kv_dim), dtype=bfloat16)
|
| 71 |
+
for _ in range(config.num_hidden_layers)
|
| 72 |
+
]
|
| 73 |
+
self.v_caches = [
|
| 74 |
+
np.zeros((1, self.max_seq_len, self.kv_dim), dtype=bfloat16)
|
| 75 |
+
for _ in range(config.num_hidden_layers)
|
| 76 |
+
]
|
| 77 |
+
|
| 78 |
+
layer_files, post_file, prefix = _find_axmodel_files(model_dir, config.num_hidden_layers)
|
| 79 |
+
|
| 80 |
+
self.decoder_sessions = []
|
| 81 |
+
for _, fname in tqdm(layer_files, desc="Init InferenceSession"):
|
| 82 |
+
session = InferenceSession(os.path.join(model_dir, fname))
|
| 83 |
+
self.decoder_sessions.append(session)
|
| 84 |
+
|
| 85 |
+
# post_file was returned by _find_axmodel_files; ensure it was found
|
| 86 |
+
if post_file is None:
|
| 87 |
+
raise FileNotFoundError("Cannot find post process .axmodel file in model_dir")
|
| 88 |
+
self.post_process_session = InferenceSession(os.path.join(model_dir, post_file))
|
| 89 |
+
print("Model loaded successfully!")
|
| 90 |
+
|
| 91 |
+
@staticmethod
|
| 92 |
+
def _top_p(probs: np.ndarray, p: float) -> np.ndarray:
|
| 93 |
+
sorted_indices = np.argsort(probs)
|
| 94 |
+
filtered = probs.copy()
|
| 95 |
+
cumulative = 0
|
| 96 |
+
for idx in sorted_indices[::-1]:
|
| 97 |
+
if cumulative >= p:
|
| 98 |
+
filtered[idx] = 0
|
| 99 |
+
cumulative += filtered[idx]
|
| 100 |
+
return filtered / cumulative
|
| 101 |
+
|
| 102 |
+
@staticmethod
|
| 103 |
+
def _softmax(logits: np.ndarray) -> np.ndarray:
|
| 104 |
+
logits = logits - logits.max()
|
| 105 |
+
exp_logits = np.exp(logits)
|
| 106 |
+
return (exp_logits / np.sum(exp_logits)).astype(np.float64)
|
| 107 |
+
|
| 108 |
+
def post_process(self, logits, top_k=1, top_p=0.9, temperature=0.6):
|
| 109 |
+
logits = logits.astype(np.float32).flatten()
|
| 110 |
+
candidate_indices = np.argpartition(logits, -top_k)[-top_k:]
|
| 111 |
+
candidate_logits = logits[candidate_indices] / temperature
|
| 112 |
+
candidate_probs = self._softmax(candidate_logits)
|
| 113 |
+
candidate_probs = self._top_p(candidate_probs, top_p)
|
| 114 |
+
candidate_probs = candidate_probs.astype(np.float64) / candidate_probs.sum()
|
| 115 |
+
chosen_idx = np.random.multinomial(1, candidate_probs).argmax()
|
| 116 |
+
next_token = candidate_indices[chosen_idx]
|
| 117 |
+
return next_token, candidate_indices, candidate_probs
|
| 118 |
+
|
| 119 |
+
def gen_slice_indices(self, token_len, prefill=128, expand=128):
|
| 120 |
+
remaining = max(0, token_len - prefill)
|
| 121 |
+
extra_blocks = (remaining + expand - 1) // expand
|
| 122 |
+
return list(range(extra_blocks + 1))
|
| 123 |
+
|
| 124 |
+
def prefill(
|
| 125 |
+
self,
|
| 126 |
+
tokenizer,
|
| 127 |
+
token_ids,
|
| 128 |
+
embed_data,
|
| 129 |
+
slice_len=128,
|
| 130 |
+
):
|
| 131 |
+
"""
|
| 132 |
+
Prefill step for chunked inference.
|
| 133 |
+
"""
|
| 134 |
+
seq_len = len(token_ids)
|
| 135 |
+
slice_indices = [i for i in range(seq_len // slice_len + 1)]
|
| 136 |
+
print(f"slice_indices: {slice_indices}")
|
| 137 |
+
# total_prefill_len = (
|
| 138 |
+
# slice_len * slice_indices[-1]
|
| 139 |
+
# if slice_indices[-1] != 0
|
| 140 |
+
# else slice_len
|
| 141 |
+
# )
|
| 142 |
+
total_prefill_len = slice_len * (slice_indices[-1] + 1)
|
| 143 |
+
# slice_indices = self.gen_slice_indices(seq_len)
|
| 144 |
+
|
| 145 |
+
if total_prefill_len > 0:
|
| 146 |
+
for slice_idx in slice_indices:
|
| 147 |
+
indices = np.arange(
|
| 148 |
+
slice_idx * slice_len,
|
| 149 |
+
(slice_idx + 1) * slice_len,
|
| 150 |
+
dtype=np.uint32
|
| 151 |
+
).reshape((1, slice_len))
|
| 152 |
+
|
| 153 |
+
mask = (
|
| 154 |
+
np.zeros((1, slice_len, slice_len * (slice_idx + 1)))
|
| 155 |
+
- 65536
|
| 156 |
+
)
|
| 157 |
+
data = np.zeros((1, slice_len, self.config.hidden_size)).astype(bfloat16)
|
| 158 |
+
for i, t in enumerate(
|
| 159 |
+
range(
|
| 160 |
+
slice_idx * slice_len,
|
| 161 |
+
(slice_idx + 1) * slice_len,
|
| 162 |
+
)
|
| 163 |
+
):
|
| 164 |
+
if t < len(token_ids):
|
| 165 |
+
mask[:, i, : slice_idx * slice_len + i + 1] = 0
|
| 166 |
+
data[:, i : i + 1, :] = (
|
| 167 |
+
embed_data[t]
|
| 168 |
+
.reshape((1, 1, self.config.hidden_size))
|
| 169 |
+
.astype(bfloat16)
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
remain_len = (
|
| 173 |
+
seq_len - slice_idx * slice_len
|
| 174 |
+
if slice_idx == slice_indices[-1]
|
| 175 |
+
else slice_len
|
| 176 |
+
)
|
| 177 |
+
mask = mask.astype(bfloat16)
|
| 178 |
+
for layer_idx in range(self.config.num_hidden_layers):
|
| 179 |
+
input_feed = {
|
| 180 |
+
"K_cache": (
|
| 181 |
+
self.k_caches[layer_idx][:, 0 : slice_len * slice_idx, :]
|
| 182 |
+
if slice_idx
|
| 183 |
+
else np.zeros((1, 1, self.config.hidden_size), dtype=bfloat16)
|
| 184 |
+
),
|
| 185 |
+
"V_cache": (
|
| 186 |
+
self.v_caches[layer_idx][:, 0 : slice_len * slice_idx, :]
|
| 187 |
+
if slice_idx
|
| 188 |
+
else np.zeros((1, 1, self.config.hidden_size), dtype=bfloat16)
|
| 189 |
+
),
|
| 190 |
+
"indices": indices,
|
| 191 |
+
"input": data,
|
| 192 |
+
"mask": mask,
|
| 193 |
+
}
|
| 194 |
+
outputs = self.decoder_sessions[layer_idx].run(None, input_feed, shape_group=slice_idx + 1)
|
| 195 |
+
self.k_caches[layer_idx][
|
| 196 |
+
:,
|
| 197 |
+
slice_idx * slice_len : slice_idx * slice_len + remain_len,
|
| 198 |
+
:,
|
| 199 |
+
] = outputs[0][:, :remain_len, :]
|
| 200 |
+
self.v_caches[layer_idx][
|
| 201 |
+
:,
|
| 202 |
+
slice_idx * slice_len : slice_idx * slice_len + remain_len,
|
| 203 |
+
:,
|
| 204 |
+
] = outputs[1][:, :remain_len, :]
|
| 205 |
+
data = outputs[2]
|
| 206 |
+
|
| 207 |
+
print("Slice prefill done:", slice_idx)
|
| 208 |
+
|
| 209 |
+
# return data[:, :remain_len, :]
|
| 210 |
+
post_out = self.post_process_session.run(
|
| 211 |
+
None,
|
| 212 |
+
{
|
| 213 |
+
"input": data[
|
| 214 |
+
:, seq_len - (len(slice_indices) - 1) * slice_len - 1, None, :
|
| 215 |
+
]
|
| 216 |
+
}
|
| 217 |
+
)[0]
|
| 218 |
+
next_token, possible_tokens, possible_probs = self.post_process(post_out)
|
| 219 |
+
possible_decoded = [tokenizer.decode([t]) for t in possible_tokens]
|
| 220 |
+
possible_probs_str = [str((t, p)) for t, p in zip(possible_decoded, possible_probs)]
|
| 221 |
+
token_ids.append(next_token)
|
| 222 |
+
return token_ids
|
| 223 |
+
|
| 224 |
+
def decode(
|
| 225 |
+
self,
|
| 226 |
+
tokenizer,
|
| 227 |
+
token_ids,
|
| 228 |
+
embed_matrix,
|
| 229 |
+
prefill_len=128,
|
| 230 |
+
slice_len=128,
|
| 231 |
+
eos_token_id=None, # 某些模型有多个 eos_token_id
|
| 232 |
+
):
|
| 233 |
+
print("answer >>", tokenizer.decode(token_ids[-1], skip_special_tokens=True), end='', flush=True)
|
| 234 |
+
mask = np.zeros((1, 1, self.max_seq_len + 1), dtype=np.float32).astype(bfloat16)
|
| 235 |
+
mask[:, :, :self.max_seq_len] -= 65536
|
| 236 |
+
seq_len = len(token_ids) - 1
|
| 237 |
+
if prefill_len > 0:
|
| 238 |
+
mask[:, :, :seq_len] = 0
|
| 239 |
+
for step_idx in range(self.max_seq_len):
|
| 240 |
+
if prefill_len > 0 and step_idx < seq_len:
|
| 241 |
+
continue
|
| 242 |
+
cur_token = token_ids[step_idx]
|
| 243 |
+
indices = np.array([step_idx], np.uint32).reshape((1, 1))
|
| 244 |
+
data = embed_matrix[cur_token, :].reshape((1, 1, self.config.hidden_size)).astype(bfloat16)
|
| 245 |
+
for layer_idx in range(self.config.num_hidden_layers):
|
| 246 |
+
input_feed = {
|
| 247 |
+
"K_cache": self.k_caches[layer_idx],
|
| 248 |
+
"V_cache": self.v_caches[layer_idx],
|
| 249 |
+
"indices": indices,
|
| 250 |
+
"input": data,
|
| 251 |
+
"mask": mask,
|
| 252 |
+
}
|
| 253 |
+
outputs = self.decoder_sessions[layer_idx].run(None, input_feed, shape_group=0)
|
| 254 |
+
self.k_caches[layer_idx][:, step_idx, :] = outputs[0][:, :, :]
|
| 255 |
+
self.v_caches[layer_idx][:, step_idx, :] = outputs[1][:, :, :]
|
| 256 |
+
data = outputs[2]
|
| 257 |
+
mask[..., step_idx] = 0
|
| 258 |
+
if step_idx < seq_len - 1:
|
| 259 |
+
continue
|
| 260 |
+
else:
|
| 261 |
+
post_out = self.post_process_session.run(None, {"input": data})[0]
|
| 262 |
+
next_token, possible_tokens, possible_probs = self.post_process(post_out)
|
| 263 |
+
if eos_token_id is not None and next_token in eos_token_id:
|
| 264 |
+
break
|
| 265 |
+
elif next_token == tokenizer.eos_token_id:
|
| 266 |
+
break
|
| 267 |
+
else:
|
| 268 |
+
pass
|
| 269 |
+
token_ids.append(next_token)
|
| 270 |
+
print(tokenizer.decode(next_token, skip_special_tokens=True), end='', flush=True)
|
| 271 |
+
|