Oleg committed on
Commit
ebc8af3
·
1 Parent(s): 9128032

Initial commit - transformed model to onnx format

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ onnx/berta-onnx/BERTA.onnx.data filter=lfs diff=lfs merge=lfs -text
37
+ onnx/berta-onnx/BERTA.onnx filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,12 @@
1
+ # Python-generated files
2
+ __pycache__/
3
+ *.py[oc]
4
+ build/
5
+ dist/
6
+ wheels/
7
+ *.egg-info
8
+
9
+ # Virtual environments
10
+ .venv
11
+ .idea
12
+ uv.lock
README.md CHANGED
@@ -1,3 +1,172 @@
1
  ---
2
+ language:
3
+ - ru
4
+ - en
5
+
6
+ pipeline_tag: sentence-similarity
7
+
8
+ tags:
9
+ - russian
10
+ - pretraining
11
+ - embeddings
12
+ - feature-extraction
13
+ - sentence-similarity
14
+ - sentence-transformers
15
+ - transformers
16
+
17
+ datasets:
18
+ - IlyaGusev/gazeta
19
+ - zloelias/lenta-ru
20
+ - HuggingFaceFW/fineweb-2
21
+ - HuggingFaceFW/fineweb
22
+
23
  license: mit
24
+ base_model: sergeyzh/LaBSE-ru-turbo
25
+
26
  ---
27
+
28
+ ## Repository for the BERTA model converted to ONNX format
29
+ Original model repository: https://huggingface.co/sergeyzh/BERTA
30
+
31
+ ## BERTA
32
+
33
+ A model for computing sentence embeddings in Russian and English, obtained by distilling the embeddings of [ai-forever/FRIDA](https://huggingface.co/ai-forever/FRIDA) (embedding size 1536, 24 layers) into [sergeyzh/LaBSE-ru-turbo](https://huggingface.co/sergeyzh/LaBSE-ru-turbo) (embedding size 768, 12 layers). FRIDA's primary mode of use, CLS pooling, is replaced with mean pooling. No other changes to the model's behavior were made. The distillation was carried out as completely as possible: embeddings of Russian and English sentences, and the behavior of the prefixes.
34
+
35
+ The model's context size matches FRIDA's: 512 tokens.
36
+
37
+ ## Prefixes
38
+ All prefixes are inherited from FRIDA.
39
+ The optimal prefix for most tasks (in the sense of giving average results across them), "categorize_entailment: ", is set as the default in [config_sentence_transformers.json](https://huggingface.co/sergeyzh/BERTA/blob/main/config_sentence_transformers.json).
40
+
41
+ The prefixes used and their effect on the model's scores in [encodechka](https://github.com/avidale/encodechka):
42
+
43
+ | Prefix | STS | PI | NLI | SA | TI |
44
+ |:-----------------------|:---------:|:---------:|:---------:|:---------:|:---------:|
45
+ | - | 0.842 | 0.757 | 0.463 | **0.830** | 0.985 |
46
+ | search_query: | 0.853 | 0.767 | 0.479 | 0.825 | 0.987 |
47
+ | search_document: | 0.831 | 0.749 | 0.463 | 0.817 | 0.986 |
48
+ | paraphrase: | 0.847 | **0.778** | 0.446 | 0.825 | 0.986 |
49
+ | categorize: | **0.857** | 0.765 | 0.501 | 0.829 | **0.988** |
50
+ | categorize_sentiment: | 0.589 | 0.535 | 0.417 | 0.805 | 0.982 |
51
+ | categorize_topic: | 0.740 | 0.521 | 0.396 | 0.770 | 0.982 |
52
+ | categorize_entailment: | 0.841 | 0.762 | **0.571** | 0.827 | 0.986 |
53
+
54
+
55
+ **Tasks:**
56
+
57
+ - Semantic text similarity (**STS**);
58
+ - Paraphrase identification (**PI**);
59
+ - Natural language inference (**NLI**);
60
+ - Sentiment analysis (**SA**);
61
+ - Toxicity identification (**TI**).
62
+
63
+
64
+
65
+ ## Metrics
66
+ Model scores on the [ruMTEB](https://habr.com/ru/companies/sberdevices/articles/831150/) benchmark:
67
+
68
+ |Model Name | Metric | FRIDA | BERTA | [rubert-mini-frida](https://huggingface.co/sergeyzh/rubert-mini-frida) | multilingual-e5-large-instruct | multilingual-e5-large |
69
+ |:-------------------------------|:--------------------|----------:|----------:|--------------------:|---------------------:|----------------------:|
70
+ |CEDRClassification | Accuracy | **0.646** | 0.622 | 0.552 | 0.500 | 0.448 |
71
+ |GeoreviewClassification | Accuracy | **0.577** | 0.548 | 0.464 | 0.559 | 0.497 |
72
+ |GeoreviewClusteringP2P | V-measure | **0.783** | 0.738 | 0.698 | 0.743 | 0.605 |
73
+ |HeadlineClassification | Accuracy | 0.890 | **0.891** | 0.880 | 0.862 | 0.758 |
74
+ |InappropriatenessClassification | Accuracy | **0.783** | 0.748 | 0.698 | 0.655 | 0.616 |
75
+ |KinopoiskClassification | Accuracy | **0.705** | 0.678 | 0.595 | 0.661 | 0.566 |
76
+ |RiaNewsRetrieval | NDCG@10 | **0.868** | 0.816 | 0.721 | 0.824 | 0.807 |
77
+ |RuBQReranking | MAP@10 | **0.771** | 0.752 | 0.711 | 0.717 | 0.756 |
78
+ |RuBQRetrieval | NDCG@10 | 0.724 | 0.710 | 0.654 | 0.692 | **0.741** |
79
+ |RuReviewsClassification | Accuracy | **0.751** | 0.723 | 0.658 | 0.686 | 0.653 |
80
+ |RuSTSBenchmarkSTS | Pearson correlation | 0.814 | 0.822 | 0.803 | **0.840** | 0.831 |
81
+ |RuSciBenchGRNTIClassification | Accuracy | **0.699** | 0.690 | 0.625 | 0.651 | 0.582 |
82
+ |RuSciBenchGRNTIClusteringP2P | V-measure | **0.670** | 0.650 | 0.586 | 0.622 | 0.520 |
83
+ |RuSciBenchOECDClassification | Accuracy | 0.546 | **0.555** | 0.493 | 0.502 | 0.445 |
84
+ |RuSciBenchOECDClusteringP2P | V-measure | **0.566** | 0.556 | 0.507 | 0.528 | 0.450 |
85
+ |SensitiveTopicsClassification | Accuracy | 0.398 | **0.399** | 0.373 | 0.323 | 0.257 |
86
+ |TERRaClassification | Average Precision | **0.665** | 0.657 | 0.606 | 0.639 | 0.584 |
87
+
88
+ |Model Name | Metric | FRIDA | BERTA | rubert-mini-frida | multilingual-e5-large-instruct | multilingual-e5-large |
89
+ |:-------------------------------|:--------------------|----------:|----------:|--------------------:|----------------------:|---------------------:|
90
+ |Classification | Accuracy | **0.707** | 0.698 | 0.631 | 0.654 | 0.588 |
91
+ |Clustering | V-measure | **0.673** | 0.648 | 0.597 | 0.631 | 0.525 |
92
+ |MultiLabelClassification | Accuracy | **0.522** | 0.510 | 0.463 | 0.412 | 0.353 |
93
+ |PairClassification | Average Precision | **0.665** | 0.657 | 0.606 | 0.639 | 0.584 |
94
+ |Reranking | MAP@10 | **0.771** | 0.752 | 0.711 | 0.717 | 0.756 |
95
+ |Retrieval | NDCG@10 | **0.796** | 0.763 | 0.687 | 0.758 | 0.774 |
96
+ |STS | Pearson correlation | 0.814 | 0.822 | 0.803 | **0.840** | 0.831 |
97
+ |Average | Average | **0.707** | 0.693 | 0.643 | 0.664 | 0.630 |
98
+
99
+
100
+
101
+ ## Using the model with the `transformers` library:
102
+
103
+ ```python
104
+ import torch
105
+ import torch.nn.functional as F
106
+ from transformers import AutoTokenizer, AutoModel
107
+
108
+
109
+ def pool(hidden_state, mask, pooling_method="mean"):
110
+     if pooling_method == "mean":
111
+         s = torch.sum(hidden_state * mask.unsqueeze(-1).float(), dim=1)
112
+         d = mask.sum(axis=1, keepdim=True).float()
113
+         return s / d
114
+     elif pooling_method == "cls":
115
+         return hidden_state[:, 0]
116
+
117
+ inputs = [
118
+     #
119
+     "paraphrase: В Ярославской области разрешили работу бань, но без посетителей",
120
+     "categorize_entailment: Женщину доставили в больницу, за ее жизнь сейчас борются врачи.",
121
+     "search_query: Сколько программистов нужно, чтобы вкрутить лампочку?",
122
+     #
123
+     "paraphrase: Ярославским баням разрешили работать без посетителей",
124
+     "categorize_entailment: Женщину спасают врачи.",
125
+     "search_document: Чтобы вкрутить лампочку, требуется три программиста: один напишет программу извлечения лампочки, другой — вкручивания лампочки, а третий проведет тестирование."
126
+ ]
127
+
128
+ tokenizer = AutoTokenizer.from_pretrained("sergeyzh/BERTA")
129
+ model = AutoModel.from_pretrained("sergeyzh/BERTA")
130
+
131
+ tokenized_inputs = tokenizer(inputs, max_length=512, padding=True, truncation=True, return_tensors="pt")
132
+
133
+ with torch.no_grad():
134
+     outputs = model(**tokenized_inputs)
135
+
136
+ embeddings = pool(
137
+     outputs.last_hidden_state,
138
+     tokenized_inputs["attention_mask"],
139
+     pooling_method="mean"
140
+ )
141
+
142
+ embeddings = F.normalize(embeddings, p=2, dim=1)
143
+ sim_scores = embeddings[:3] @ embeddings[3:].T
144
+ print(sim_scores.diag().tolist())
145
+ # [0.9530372023582458, 0.866746723651886, 0.7839133143424988]
146
+ # [0.9360030293464661, 0.8591322302818298, 0.728583037853241] - FRIDA
147
+ ```
148
+
149
+ ## Usage with `sentence_transformers` (sentence-transformers>=2.4.0):
150
+
151
+ ```python
152
+ from sentence_transformers import SentenceTransformer
153
+
154
+ # loads model with mean pooling
155
+ model = SentenceTransformer("sergeyzh/BERTA")
156
+
157
+ paraphrase = model.encode(["В Ярославской области разрешили работу бань, но без посетителей", "Ярославским баням разрешили работать без посетителей"], prompt="paraphrase: ")
158
+ print(paraphrase[0] @ paraphrase[1].T)
159
+ # 0.9530372
160
+ # 0.9360032 - FRIDA
161
+
162
+ categorize_entailment = model.encode(["Женщину доставили в больницу, за ее жизнь сейчас борются врачи.", "Женщину спасают врачи."], prompt="categorize_entailment: ")
163
+ print(categorize_entailment[0] @ categorize_entailment[1].T)
164
+ # 0.8667469
165
+ # 0.8591322 - FRIDA
166
+
167
+ query_embedding = model.encode("Сколько программистов нужно, чтобы вкрутить лампочку?", prompt="search_query: ")
168
+ document_embedding = model.encode("Чтобы вкрутить лампочку, требуется три программиста: один напишет программу извлечения лампочки, другой — вкручивания лампочки, а третий проведет тестирование.", prompt="search_document: ")
169
+ print(query_embedding @ document_embedding.T)
170
+ # 0.7839136
171
+ # 0.7285831 - FRIDA
172
+ ```
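The README documents BERTA with mean pooling but only shows PyTorch usage. Below is a minimal sketch of what inference against the exported graph might look like. It assumes the files sit in `onnx/berta-onnx/` as added by this commit and that the graph takes `input_ids` and `attention_mask` and returns `last_hidden_state` (the names passed to `torch.onnx.export` in the conversion notebook). The helper names `mean_pool`, `l2_normalize`, and `embed` are illustrative and not part of the repository.

```python
import numpy as np


def mean_pool(last_hidden_state, attention_mask):
    """Average token vectors, ignoring padded positions."""
    mask = attention_mask[..., None].astype(last_hidden_state.dtype)
    summed = (last_hidden_state * mask).sum(axis=1)
    counts = np.clip(mask.sum(axis=1), 1e-9, None)  # guard against empty rows
    return summed / counts


def l2_normalize(vectors):
    """Scale each row to unit length so dot products are cosine similarities."""
    return vectors / np.linalg.norm(vectors, axis=1, keepdims=True)


def embed(texts, model_dir="onnx/berta-onnx"):
    """Hypothetical helper: tokenize, run the ONNX graph, pool, normalize."""
    # Imported lazily so the pooling helpers work without these packages.
    import onnxruntime as ort
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    session = ort.InferenceSession(f"{model_dir}/BERTA.onnx")
    enc = tokenizer(texts, max_length=512, padding=True,
                    truncation=True, return_tensors="np")
    feeds = {
        "input_ids": enc["input_ids"].astype(np.int64),
        "attention_mask": enc["attention_mask"].astype(np.int64),
    }
    hidden = session.run(None, feeds)[0]
    return l2_normalize(mean_pool(hidden, feeds["attention_mask"]))


# Pooling demo on synthetic data: the third token is padding and is ignored.
hidden = np.array([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])
mask = np.array([[1, 1, 0]])
pooled = mean_pool(hidden, mask)
print(pooled)
```

Prefixes such as `"search_query: "` would still need to be prepended to each text by the caller, since the ONNX graph only sees token ids.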
onnx/berta-onnx/BERTA.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:465286621e28ba4663fd22b84b90a1135efc95519ca34917f536cd87e6fa2b84
3
+ size 1222522
onnx/berta-onnx/BERTA.onnx.data ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c67530b9d0380f15d915fbf58d97b99c9cf56d6082fe96ae9ab36378783de195
3
+ size 513410048
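The two entries above are Git LFS pointer files, not the model itself: each records the spec version, a SHA-256 object id, and the byte size of the real blob. The small `BERTA.onnx` presumably holds the graph, and the much larger `BERTA.onnx.data` the weights. A short sketch of reading such a pointer, which can help check that a download is complete. `parse_lfs_pointer` is a hypothetical helper, and the pointer text is copied from the `BERTA.onnx.data` entry.

```python
def parse_lfs_pointer(text: str) -> dict:
    """Split a Git LFS pointer file into its version, hash, and size fields."""
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    algo, _, digest = fields["oid"].partition(":")
    return {
        "version": fields["version"],
        "hash_algo": algo,
        "digest": digest,
        "size": int(fields["size"]),
    }


POINTER = """\
version https://git-lfs.github.com/spec/v1
oid sha256:c67530b9d0380f15d915fbf58d97b99c9cf56d6082fe96ae9ab36378783de195
size 513410048
"""

info = parse_lfs_pointer(POINTER)
# Report the expected size of the real file in MiB.
print(info["hash_algo"], info["size"], round(info["size"] / 2**20, 1), "MiB")
```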
onnx/berta-onnx/special_tokens_map.json ADDED
@@ -0,0 +1,37 @@
1
+ {
2
+ "cls_token": {
3
+ "content": "[CLS]",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "mask_token": {
10
+ "content": "[MASK]",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "[PAD]",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "sep_token": {
24
+ "content": "[SEP]",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "unk_token": {
31
+ "content": "[UNK]",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ }
37
+ }
onnx/berta-onnx/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
onnx/berta-onnx/tokenizer_config.json ADDED
@@ -0,0 +1,66 @@
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "[UNK]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "[CLS]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "[SEP]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "4": {
36
+ "content": "[MASK]",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "clean_up_tokenization_spaces": true,
45
+ "cls_token": "[CLS]",
46
+ "do_basic_tokenize": true,
47
+ "do_lower_case": false,
48
+ "extra_special_tokens": {},
49
+ "mask_token": "[MASK]",
50
+ "max_length": 512,
51
+ "model_max_length": 512,
52
+ "never_split": null,
53
+ "pad_to_multiple_of": null,
54
+ "pad_token": "[PAD]",
55
+ "pad_token_type_id": 0,
56
+ "padding_side": "right",
57
+ "repo_type": "model",
58
+ "sep_token": "[SEP]",
59
+ "stride": 0,
60
+ "strip_accents": null,
61
+ "tokenize_chinese_chars": true,
62
+ "tokenizer_class": "BertTokenizer",
63
+ "truncation_side": "right",
64
+ "truncation_strategy": "longest_first",
65
+ "unk_token": "[UNK]"
66
+ }
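The configuration above fixes `model_max_length` at 512, pads on the right, and maps `[PAD]` to id 0 (per `added_tokens_decoder`). A rough sketch of what that means for a batch follows. `pad_batch` is a hypothetical stand-in for what the tokenizer does with `padding=True, truncation=True`, and the token ids are illustrative. The real tokenizer also keeps the trailing `[SEP]` when truncating, which this simplification does not.

```python
PAD_ID = 0        # "[PAD]" in added_tokens_decoder
MAX_LENGTH = 512  # model_max_length


def pad_batch(sequences, max_length=MAX_LENGTH, pad_id=PAD_ID):
    """Clip each sequence, then right-pad all of them to the longest one."""
    clipped = [seq[:max_length] for seq in sequences]
    width = max(len(seq) for seq in clipped)
    input_ids = [seq + [pad_id] * (width - len(seq)) for seq in clipped]
    attention_mask = [[1] * len(seq) + [0] * (width - len(seq)) for seq in clipped]
    return input_ids, attention_mask


# 2 is "[CLS]" and 3 is "[SEP]" in this vocabulary; the other ids are examples.
batch = [[2, 570, 11028, 3], [2, 3007, 3]]
ids, mask = pad_batch(batch)
print(ids)
print(mask)
```

The attention mask is what lets mean pooling skip the padded positions.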
onnx/berta-onnx/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
pyproject.toml ADDED
@@ -0,0 +1,28 @@
1
+ [project]
2
+ name = "frida-transformed"
3
+ version = "0.1.0"
4
+ description = "Add your description here"
5
+ readme = "README.md"
6
+ requires-python = ">=3.13, <3.14"
7
+ dependencies = [
8
+ 'onnx == 1.20.1',
9
+ 'onnxruntime == 1.23.2',
10
+ 'onnxscript == 0.6.0',
11
+ 'onnx-safetensors == 1.5.0',
12
+ 'torch == 2.10.0',
13
+ 'torchvision == 0.25.0',
14
+ 'transformers == 4.57.3',
15
+ 'pycuda == 2026.1',
16
+ "ipykernel>=7.2.0",
17
+ "pip>=26.0.1",
18
+ "uv>=0.10.2",
19
+ "jupyter>=1.1.1",
20
+ "ipywidgets>=8.1.8",
21
+ "tqdm>=4.67.3",
22
+ "ipython>=9.10.0",
23
+ ]
24
+
25
+ [tool.uv.workspace]
26
+ members = [
27
+ "frida-transformed",
28
+ ]
safetensors_to_onnx.ipynb ADDED
@@ -0,0 +1,380 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "metadata": {
5
+ "collapsed": true,
6
+ "ExecuteTime": {
7
+ "end_time": "2026-02-12T12:52:46.678786554Z",
8
+ "start_time": "2026-02-12T12:52:43.490350354Z"
9
+ }
10
+ },
11
+ "cell_type": "code",
12
+ "source": [
13
+ "import torch\n",
14
+ "from torch.export import Dim\n",
15
+ "from transformers import BertModel, AutoModel, AutoTokenizer\n",
16
+ "from pathlib import Path\n",
17
+ "import onnxruntime as ort\n",
18
+ "import numpy as np\n",
19
+ "from inspect import signature"
20
+ ],
21
+ "id": "2b3977272abf14d9",
22
+ "outputs": [],
23
+ "execution_count": 1
24
+ },
25
+ {
26
+ "metadata": {
27
+ "ExecuteTime": {
28
+ "end_time": "2026-02-12T12:52:46.726391124Z",
29
+ "start_time": "2026-02-12T12:52:46.691717774Z"
30
+ }
31
+ },
32
+ "cell_type": "code",
33
+ "source": [
34
+ "# MODEL_SOURCE_ID = \"sergeyzh/BERTA\"\n",
35
+ "MODEL_SOURCE_ID = \"../BERTA\"\n",
36
+ "MODEL_TARGET_PATH = Path(\"onnx/berta-onnx\")\n",
37
+ "ONNX_FILE_NAME = \"BERTA.onnx\"\n",
38
+ "\n",
39
+ "print(\"=\"*50)\n",
40
+ "print(f\"Подготовка директории: {MODEL_TARGET_PATH}\")\n",
41
+ "MODEL_TARGET_PATH.mkdir(parents=True, exist_ok=True)"
42
+ ],
43
+ "id": "494fc15203b0fb89",
44
+ "outputs": [
45
+ {
46
+ "name": "stdout",
47
+ "output_type": "stream",
48
+ "text": [
49
+ "==================================================\n",
50
+ "Подготовка директории: onnx/berta-onnx\n"
51
+ ]
52
+ }
53
+ ],
54
+ "execution_count": 2
55
+ },
56
+ {
57
+ "metadata": {
58
+ "ExecuteTime": {
59
+ "end_time": "2026-02-12T12:52:46.862603179Z",
60
+ "start_time": "2026-02-12T12:52:46.739714466Z"
61
+ }
62
+ },
63
+ "cell_type": "code",
64
+ "source": [
65
+ "# 1. Load the model and tokenizer\n",
66
+ "print(f\"Загрузка модели и токенизатора из '{MODEL_SOURCE_ID}'...\")\n",
67
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_SOURCE_ID, repo_type=\"model\")\n",
68
+ "model = AutoModel.from_pretrained(MODEL_SOURCE_ID)\n",
69
+ "model.eval()\n",
70
+ "\n",
71
+ "# 2. Create test inputs\n",
72
+ "print(\"Создание тестовых входных данных...\")\n",
73
+ "test_texts = [\n",
74
+ " \"paraphrase: В Ярославской области разрешили работу бань, но без посетителей\",\n",
75
+ " \"search_query: Сколько программистов нужно, чтобы вкрутить лампочку?\",\n",
76
+ " \"categorize_entailment: Женщину доставили в больницу, за ее жизнь сейчас борются врачи.\"\n",
77
+ "]\n",
78
+ "\n",
79
+ "dummy_inputs = tokenizer(\n",
80
+ " test_texts,\n",
81
+ " max_length=512,\n",
82
+ " padding=\"max_length\",\n",
83
+ " truncation=True,\n",
84
+ " return_tensors=\"pt\"\n",
85
+ ")\n",
86
+ "print(dummy_inputs)"
87
+ ],
88
+ "id": "4f9f5febc6f07769",
89
+ "outputs": [
90
+ {
91
+ "name": "stdout",
92
+ "output_type": "stream",
93
+ "text": [
94
+ "Загрузка модели и токенизатора из '../BERTA'...\n",
95
+ "Создание тестовых входных данных...\n",
96
+ "{'input_ids': tensor([[ 2, 570, 11028, ..., 0, 0, 0],\n",
97
+ " [ 2, 3007, 67, ..., 0, 0, 0],\n",
98
+ " [ 2, 46369, 998, ..., 0, 0, 0]]), 'token_type_ids': tensor([[0, 0, 0, ..., 0, 0, 0],\n",
99
+ " [0, 0, 0, ..., 0, 0, 0],\n",
100
+ " [0, 0, 0, ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, ..., 0, 0, 0],\n",
101
+ " [1, 1, 1, ..., 0, 0, 0],\n",
102
+ " [1, 1, 1, ..., 0, 0, 0]])}\n"
103
+ ]
104
+ }
105
+ ],
106
+ "execution_count": 3
107
+ },
108
+ {
109
+ "metadata": {
110
+ "ExecuteTime": {
111
+ "end_time": "2026-02-12T12:52:46.899958136Z",
112
+ "start_time": "2026-02-12T12:52:46.868506089Z"
113
+ }
114
+ },
115
+ "cell_type": "code",
116
+ "source": [
117
+ "# print(model)\n",
118
+ "print(signature(model.forward))"
119
+ ],
120
+ "id": "8bdce4e5bc593383",
121
+ "outputs": [
122
+ {
123
+ "name": "stdout",
124
+ "output_type": "stream",
125
+ "text": [
126
+ "(input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[transformers.cache_utils.Cache] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.Tensor] = None) -> Union[tuple[torch.Tensor], transformers.modeling_outputs.BaseModelOutputWithPoolingAndCrossAttentions]\n"
127
+ ]
128
+ }
129
+ ],
130
+ "execution_count": 4
131
+ },
132
+ {
133
+ "metadata": {
134
+ "ExecuteTime": {
135
+ "end_time": "2026-02-12T12:52:56.427369911Z",
136
+ "start_time": "2026-02-12T12:52:46.902043777Z"
137
+ }
138
+ },
139
+ "cell_type": "code",
140
+ "source": [
141
+ "# 3. Export with two inputs\n",
142
+ "onnx_model_path = MODEL_TARGET_PATH / ONNX_FILE_NAME\n",
143
+ "print(f\"Экспорт модели в ONNX формат: {onnx_model_path}\")\n",
144
+ "\n",
145
+ "# For dynamic_shapes\n",
146
+ "batch_size = Dim(\"batch_size\", min=1, max=64) # Optional: add min/max constraints\n",
147
+ "sequence_length = Dim(\"sequence_length\", min=2, max=512)\n",
148
+ "\n",
149
+ "# dynamic_shapes = {\n",
150
+ "# \"input_ids\": {0: batch_size, 1: sequence_length},\n",
151
+ "# \"attention_mask\": {0: batch_size, 1: sequence_length},\n",
152
+ "# \"last_hidden_state\": {0: batch_size, 1: sequence_length}\n",
153
+ "# }\n",
154
+ "\n",
155
+ "# In case of issues use dynamo_export instead of dynamo=True\n",
156
+ "torch.onnx.export(\n",
157
+ " model,\n",
158
+ " (dummy_inputs[\"input_ids\"], dummy_inputs[\"attention_mask\"]),\n",
159
+ " onnx_model_path.as_posix(),\n",
160
+ " input_names=[\"input_ids\", \"attention_mask\"],\n",
161
+ " output_names=[\"last_hidden_state\"],\n",
162
+ " opset_version=20, # Maybe update\n",
163
+ " dynamic_shapes = {\n",
164
+ " \"input_ids\": {0: batch_size, 1: sequence_length},\n",
165
+ " \"attention_mask\": {0: batch_size, 1: sequence_length}\n",
166
+ " },\n",
167
+ " verbose=True,\n",
168
+ " dynamo=True\n",
169
+ ")\n",
170
+ "# 4. Save the tokenizer\n",
171
+ "print(f\"Сохранение токенизатора в '{MODEL_TARGET_PATH}'...\")\n",
172
+ "tokenizer.save_pretrained(MODEL_TARGET_PATH)\n",
173
+ "\n",
174
+ "print(\"Конвертация завершена успешно!\")"
175
+ ],
176
+ "id": "87d59bf71ed545dc",
177
+ "outputs": [
178
+ {
179
+ "name": "stdout",
180
+ "output_type": "stream",
181
+ "text": [
182
+ "Экспорт модели в ONNX формат: onnx/berta-onnx/BERTA.onnx\n"
183
+ ]
184
+ },
185
+ {
186
+ "name": "stderr",
187
+ "output_type": "stream",
188
+ "text": [
189
+ "W0212 14:52:47.799000 19280 torch/onnx/_internal/exporter/_schemas.py:455] Missing annotation for parameter 'input' from (input, boxes, output_size: 'Sequence[int]', spatial_scale: 'float' = 1.0, sampling_ratio: 'int' = -1, aligned: 'bool' = False). Treating as an Input.\n",
190
+ "W0212 14:52:47.800000 19280 torch/onnx/_internal/exporter/_schemas.py:455] Missing annotation for parameter 'boxes' from (input, boxes, output_size: 'Sequence[int]', spatial_scale: 'float' = 1.0, sampling_ratio: 'int' = -1, aligned: 'bool' = False). Treating as an Input.\n",
191
+ "W0212 14:52:47.801000 19280 torch/onnx/_internal/exporter/_schemas.py:455] Missing annotation for parameter 'input' from (input, boxes, output_size: 'Sequence[int]', spatial_scale: 'float' = 1.0). Treating as an Input.\n",
192
+ "W0212 14:52:47.802000 19280 torch/onnx/_internal/exporter/_schemas.py:455] Missing annotation for parameter 'boxes' from (input, boxes, output_size: 'Sequence[int]', spatial_scale: 'float' = 1.0). Treating as an Input.\n"
193
+ ]
194
+ },
195
+ {
196
+ "name": "stdout",
197
+ "output_type": "stream",
198
+ "text": [
199
+ "[torch.onnx] Obtain model graph for `BertModel([...]` with `torch.export.export(..., strict=False)`...\n",
200
+ "[torch.onnx] Obtain model graph for `BertModel([...]` with `torch.export.export(..., strict=False)`... ✅\n",
201
+ "[torch.onnx] Run decomposition...\n"
202
+ ]
203
+ },
204
+ {
205
+ "name": "stderr",
206
+ "output_type": "stream",
207
+ "text": [
208
+ "/home/lavrentiy/Projects/FRIDA-transformed/.venv/lib/python3.13/site-packages/torch/cuda/__init__.py:435: UserWarning: \n",
209
+ " Found GPU0 NVIDIA GeForce GTX 1060 6GB which is of cuda capability 6.1.\n",
210
+ " Minimum and Maximum cuda capability supported by this version of PyTorch is\n",
211
+ " (7.0) - (12.0)\n",
212
+ " \n",
213
+ " queued_call()\n",
214
+ "/home/lavrentiy/Projects/FRIDA-transformed/.venv/lib/python3.13/site-packages/torch/cuda/__init__.py:435: UserWarning: \n",
215
+ " Please install PyTorch with a following CUDA\n",
216
+ " configurations: 12.6 following instructions at\n",
217
+ " https://pytorch.org/get-started/locally/\n",
218
+ " \n",
219
+ " queued_call()\n",
220
+ "/home/lavrentiy/Projects/FRIDA-transformed/.venv/lib/python3.13/site-packages/torch/cuda/__init__.py:435: UserWarning: \n",
221
+ "NVIDIA GeForce GTX 1060 6GB with CUDA capability sm_61 is not compatible with the current PyTorch installation.\n",
222
+ "The current PyTorch install supports CUDA capabilities sm_70 sm_75 sm_80 sm_86 sm_90 sm_100 sm_120.\n",
223
+ "If you want to use the NVIDIA GeForce GTX 1060 6GB GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/\n",
224
+ "\n",
225
+ " queued_call()\n",
226
+ "/home/lavrentiy/.local/share/uv/python/cpython-3.13.11-linux-x86_64-gnu/lib/python3.13/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.\n",
227
+ " return cls.__new__(cls, *args)\n"
228
+ ]
229
+ },
230
+ {
231
+ "name": "stdout",
232
+ "output_type": "stream",
233
+ "text": [
234
+ "[torch.onnx] Run decomposition... ✅\n",
235
+ "[torch.onnx] Translate the graph into ONNX...\n",
236
+ "[torch.onnx] Translate the graph into ONNX... ✅\n"
237
+ ]
238
+ },
239
+ {
240
+ "name": "stderr",
241
+ "output_type": "stream",
242
+ "text": [
243
+ "/home/lavrentiy/Projects/FRIDA-transformed/.venv/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_onnx_program.py:460: UserWarning: # The axis name: batch_size will not be used, since it shares the same shape constraints with another axis: batch_size.\n",
244
+ " rename_mapping = _dynamic_shapes.create_rename_mapping(\n",
245
+ "/home/lavrentiy/Projects/FRIDA-transformed/.venv/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_onnx_program.py:460: UserWarning: # The axis name: sequence_length will not be used, since it shares the same shape constraints with another axis: sequence_length.\n",
246
+ " rename_mapping = _dynamic_shapes.create_rename_mapping(\n"
247
+ ]
248
+ },
249
+ {
250
+ "name": "stdout",
251
+ "output_type": "stream",
252
+ "text": [
253
+ "Applied 68 of general pattern rewrite rules.\n",
254
+ "Сохранение токенизатора в 'onnx/berta-onnx'...\n",
255
+ "Конвертация завершена успешно!\n"
256
+ ]
257
+ }
258
+ ],
259
+ "execution_count": 5
260
+ },
261
+ {
262
+ "metadata": {
263
+ "ExecuteTime": {
264
+ "end_time": "2026-02-12T12:52:56.931194388Z",
265
+ "start_time": "2026-02-12T12:52:56.428745759Z"
266
+ }
267
+ },
268
+ "cell_type": "code",
269
+ "source": [
270
+ "# 5. Test and compare the results\n",
271
+ "print(\"\\n\" + \"=\"*50)\n",
272
+ "print(\"ТЕСТИРОВАНИЕ РЕЗУЛЬТАТОВ\")\n",
273
+ "\n",
274
+ "def cls_pooling(hidden_state, attention_mask):\n",
275
+ " \"\"\"CLS pooling to obtain embeddings\"\"\"\n",
276
+ " return hidden_state[:, 0]\n",
277
+ "\n",
278
+ "def normalize_embeddings(embeddings):\n",
279
+ " \"\"\"L2-normalize the embeddings\"\"\"\n",
280
+ " return embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)\n",
281
+ "\n",
282
+ "# Test with the original model\n",
283
+ "print(\"Тестирование оригинальной модели...\")\n",
284
+ "with torch.no_grad():\n",
285
+ " original_inputs = tokenizer(\n",
286
+ " test_texts,\n",
287
+ " max_length=512,\n",
288
+ " padding=True,\n",
289
+ " truncation=True,\n",
290
+ " return_tensors=\"pt\"\n",
291
+ " )\n",
292
+ " original_outputs = model(**original_inputs)\n",
293
+ " original_embeddings = cls_pooling(\n",
294
+ " original_outputs.last_hidden_state,\n",
295
+ " original_inputs[\"attention_mask\"]\n",
296
+ " )\n",
297
+ " original_embeddings = torch.nn.functional.normalize(original_embeddings, p=2, dim=1)\n",
298
+ "\n",
299
+ "# Test with the ONNX model\n",
300
+ "print(\"Тестирование ONNX модели...\")\n",
301
+ "onnx_session = ort.InferenceSession(onnx_model_path.as_posix())\n",
302
+ "\n",
303
+ "onnx_inputs = tokenizer(\n",
304
+ " test_texts,\n",
305
+ " max_length=512,\n",
306
+ " padding=True,\n",
307
+ " truncation=True,\n",
308
+ " return_tensors=\"np\"\n",
309
+ ")\n",
310
+ "\n",
311
+ "\n",
312
+ "onnx_inputs_int64 = {\n",
313
+ " \"input_ids\": onnx_inputs[\"input_ids\"].astype(np.int64),\n",
314
+ " \"attention_mask\": onnx_inputs[\"attention_mask\"].astype(np.int64)\n",
315
+ "}\n",
316
+ "\n",
317
+ "onnx_outputs = onnx_session.run(None, onnx_inputs_int64)[0]\n",
318
+ "\n",
319
+ "onnx_embeddings = onnx_outputs[:, 0]\n",
320
+ "onnx_embeddings = normalize_embeddings(onnx_embeddings)\n",
321
+ "\n",
322
+ "cosine_similarity = np.sum(original_embeddings.numpy() * onnx_embeddings, axis=1)\n",
323
+ "print(f\"\\nCosine similarity между оригинальной и ONNX моделью:\")\n",
324
+ "for i, sim in enumerate(cosine_similarity):\n",
325
+ " print(f\" Текст {i+1}: {sim:.6f}\")\n",
326
+ "print(f\"Средняя схожесть: {np.mean(cosine_similarity):.6f}\")\n",
327
+ "\n",
328
+ "print(\"\\n\" + \"=\"*50)\n",
329
+ "print(\"ГОТОВО! Модель успешно конвертирована и протестирована.\")\n",
330
+ "print(f\"Путь к модели: {MODEL_TARGET_PATH.resolve()}\")"
331
+ ],
332
+ "id": "91a5740805f8e829",
333
+ "outputs": [
334
+ {
335
+ "name": "stdout",
336
+ "output_type": "stream",
337
+ "text": [
338
+ "\n",
339
+ "==================================================\n",
340
+ "ТЕСТИРОВАНИЕ РЕЗУЛЬТАТОВ\n",
341
+ "Тестирование оригинальной модели...\n",
342
+ "Тестирование ONNX модели...\n",
343
+ "\n",
344
+ "Cosine similarity между оригинальной и ONNX моделью:\n",
345
+ " Текст 1: 1.000000\n",
346
+ " Текст 2: 1.000000\n",
347
+ " Текст 3: 1.000000\n",
348
+ "Средняя схожесть: 1.000000\n",
349
+ "\n",
350
+ "==================================================\n",
351
+ "ГОТОВО! Модель успешно конвертирована и протестирована.\n",
352
+ "Путь к модели: /home/lavrentiy/Projects/BERTA-transformed/onnx/berta-onnx\n"
353
+ ]
354
+ }
355
+ ],
356
+ "execution_count": 6
357
+ }
358
+ ],
359
+ "metadata": {
360
+ "kernelspec": {
361
+ "display_name": "Python 3",
362
+ "language": "python",
363
+ "name": "python3"
364
+ },
365
+ "language_info": {
366
+ "codemirror_mode": {
367
+ "name": "ipython",
368
+ "version": 2
369
+ },
370
+ "file_extension": ".py",
371
+ "mimetype": "text/x-python",
372
+ "name": "python",
373
+ "nbconvert_exporter": "python",
374
+ "pygments_lexer": "ipython2",
375
+ "version": "2.7.6"
376
+ }
377
+ },
378
+ "nbformat": 4,
379
+ "nbformat_minor": 5
380
+ }
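The export cell declares `batch_size` between 1 and 64 and `sequence_length` between 2 and 512 as dynamic dimensions. Whether ONNX Runtime rejects inputs outside those ranges is not shown in this commit, so a caller may prefer to stay within them explicitly. A minimal sketch, with `chunked` as a hypothetical helper:

```python
MAX_BATCH = 64  # upper bound given to Dim("batch_size") at export time


def chunked(items, size=MAX_BATCH):
    """Yield consecutive slices of at most `size` items."""
    if size < 1:
        raise ValueError("size must be at least 1")
    for start in range(0, len(items), size):
        yield items[start:start + size]


# 150 placeholder texts are split into batches that respect the bound.
texts = [f"search_query: text {i}" for i in range(150)]
batch_sizes = [len(batch) for batch in chunked(texts)]
print(batch_sizes)
```

Each chunk can then be tokenized with `max_length=512, truncation=True`, which keeps the sequence dimension within the declared range as well.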
safetensors_to_onnx.py ADDED
@@ -0,0 +1,136 @@
1
+ import torch
2
+ from torch.export import Dim
3
+ from transformers import T5EncoderModel, AutoTokenizer
4
+ from pathlib import Path
5
+ import onnxruntime as ort
6
+ import numpy as np
7
+
8
+
9
+ # MODEL_SOURCE_ID = "ai-forever/FRIDA"
10
+ MODEL_SOURCE_ID = "../FRIDA"
11
+ MODEL_TARGET_PATH = Path("onnx/frida-onnx")
12
+ ONNX_FILE_NAME = "FRIDA.onnx"
13
+
14
+ print("="*50)
15
+ print(f"Preparing directory: {MODEL_TARGET_PATH}")
16
+ MODEL_TARGET_PATH.mkdir(parents=True, exist_ok=True)
17
+
18
+ # 1. Load the model and tokenizer
19
+ print(f"Loading the model and tokenizer from '{MODEL_SOURCE_ID}'...")
20
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_SOURCE_ID, repo_type="model")
21
+ model = T5EncoderModel.from_pretrained(MODEL_SOURCE_ID)
22
+ model.eval()
23
+
24
+ # 2. Create test inputs
25
+ print("Creating test input data...")
26
+ test_texts = [
27
+     "paraphrase: В Ярославской области разрешили работу бань, но без посетителей",
28
+     "search_query: Сколько программистов нужно, чтобы вкрутить лампочку?",
29
+     "categorize_entailment: Женщину доставили в больницу, за ее жизнь сейчас борются врачи."
30
+ ]
31
+
32
+ dummy_inputs = tokenizer(
33
+     test_texts,
34
+     max_length=512,
35
+     padding="max_length",
36
+     truncation=True,
37
+     return_tensors="pt"
38
+ )
39
+
40
+ # 3. Export with two inputs
41
+ onnx_model_path = MODEL_TARGET_PATH / ONNX_FILE_NAME
42
+ print(f"Exporting the model to ONNX format: {onnx_model_path}")
43
+
44
+ # For dynamic_shapes
45
+ batch_size = Dim("batch_size", min=1, max=64) # Optional: add min/max constraints
46
+ sequence_length = Dim("sequence_length", min=2, max=512)
47
+
48
+ # dynamic_shapes = {
49
+ # "input_ids": {0: batch_size, 1: sequence_length},
50
+ # "attention_mask": {0: batch_size, 1: sequence_length},
51
+ # "last_hidden_state": {0: batch_size, 1: sequence_length}
52
+ # }
53
+
54
+ # In case of issues use dynamo_export instead of dynamo=True
55
+ torch.onnx.export(
56
+     model,
57
+     (dummy_inputs["input_ids"], dummy_inputs["attention_mask"]),
58
+     onnx_model_path.as_posix(),
59
+     input_names=["input_ids", "attention_mask"],
60
+     output_names=["last_hidden_state"],
61
+     opset_version=20, # Maybe update
62
+     dynamic_shapes={
63
+         "input_ids": {0: batch_size, 1: sequence_length},
64
+         "attention_mask": {0: batch_size, 1: sequence_length}
65
+     },
66
+     verbose=False,
67
+     dynamo=True
68
+ )
69
+
70
+ # 4. Save the tokenizer
71
+ print(f"Saving the tokenizer to '{MODEL_TARGET_PATH}'...")
72
+ tokenizer.save_pretrained(MODEL_TARGET_PATH)
73
+
74
+ print("Conversion completed successfully!")
75
+
76
+ # 5. Test and compare the results
77
+ print("\n" + "="*50)
78
+ print("TESTING THE RESULTS")
79
+
80
+ def cls_pooling(hidden_state, attention_mask):
81
+     """CLS pooling to obtain embeddings"""
82
+     return hidden_state[:, 0]
83
+
84
+ def normalize_embeddings(embeddings):
85
+     """L2-normalize the embeddings"""
86
+     return embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
87
+
88
+ # Test with the original model
89
+ print("Testing the original model...")
90
+ with torch.no_grad():
91
+     original_inputs = tokenizer(
92
+         test_texts,
93
+         max_length=512,
94
+         padding=True,
95
+         truncation=True,
96
+         return_tensors="pt"
97
+     )
98
+     original_outputs = model(**original_inputs)
99
+     original_embeddings = cls_pooling(
100
+         original_outputs.last_hidden_state,
101
+         original_inputs["attention_mask"]
102
+     )
103
+     original_embeddings = torch.nn.functional.normalize(original_embeddings, p=2, dim=1)
104
+
105
+ # Test with the ONNX model
106
+ print("Testing the ONNX model...")
107
+ onnx_session = ort.InferenceSession(onnx_model_path.as_posix())
108
+
109
+ onnx_inputs = tokenizer(
110
+     test_texts,
111
+     max_length=512,
112
+     padding=True,
113
+     truncation=True,
114
+     return_tensors="np"
115
+ )
116
+
117
+
118
+ onnx_inputs_int64 = {
119
+     "input_ids": onnx_inputs["input_ids"].astype(np.int64),
120
+     "attention_mask": onnx_inputs["attention_mask"].astype(np.int64)
121
+ }
122
+
123
+ onnx_outputs = onnx_session.run(None, onnx_inputs_int64)[0]
124
+
125
+ onnx_embeddings = onnx_outputs[:, 0]
126
+ onnx_embeddings = normalize_embeddings(onnx_embeddings)
127
+
128
+ cosine_similarity = np.sum(original_embeddings.numpy() * onnx_embeddings, axis=1)
129
+ print("\nCosine similarity between the original and ONNX models:")
130
+ for i, sim in enumerate(cosine_similarity):
131
+     print(f" Text {i+1}: {sim:.6f}")
132
+ print(f"Mean similarity: {np.mean(cosine_similarity):.6f}")
133
+
134
+ print("\n" + "="*50)
135
+ print("DONE! The model was converted and tested successfully.")
136
+ print(f"Model path: {MODEL_TARGET_PATH.resolve()}")
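The check at the end of the script prints cosine similarity to six decimals, which can read 1.000000 even when individual components differ slightly. As a complement, here is a sketch that also reports the largest absolute difference, assuming two arrays of row-wise embeddings with the same shape. `compare_embeddings` is a hypothetical name, and the arrays below are synthetic, not results from this model.

```python
import numpy as np


def compare_embeddings(reference, candidate):
    """Return per-row cosine similarity and the largest element-wise gap."""
    reference = np.asarray(reference, dtype=np.float64)
    candidate = np.asarray(candidate, dtype=np.float64)
    ref_unit = reference / np.linalg.norm(reference, axis=1, keepdims=True)
    cand_unit = candidate / np.linalg.norm(candidate, axis=1, keepdims=True)
    cosine = np.sum(ref_unit * cand_unit, axis=1)
    max_abs_diff = float(np.max(np.abs(reference - candidate)))
    return cosine, max_abs_diff


# Synthetic example: every component is shifted by 1e-4.
reference = np.array([[0.6, 0.8], [1.0, 0.0]])
candidate = reference + 1e-4
cosine, max_abs_diff = compare_embeddings(reference, candidate)
print(cosine, max_abs_diff)
```

Applied to the script's variables, the same function could take `original_outputs.last_hidden_state[:, 0].numpy()` and `onnx_outputs[:, 0]`, so the comparison runs on the raw vectors before normalization.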