geologist387 committed on
Commit
8b5144a
·
1 Parent(s): eef0ae4

Added signature for viewing model inputs

Browse files
Files changed (2) hide show
  1. pyproject.toml +13 -7
  2. safetensors_to_onnx.ipynb +257 -0
pyproject.toml CHANGED
@@ -5,15 +5,21 @@ description = "Add your description here"
5
  readme = "README.md"
6
  requires-python = ">=3.13, <3.14"
7
  dependencies = [
8
- 'onnx == 1.20.0',
9
  'onnxruntime == 1.23.2',
10
- 'onnxscript == 0.5.7',
11
- 'onnx-safetensors == 1.2.0',
12
- 'torch == 2.9.1',
13
- 'torchvision == 0.24.1',
14
  'transformers == 4.57.3',
15
- 'tensorrt == 10.14.1.48.post1',
16
- 'pycuda == 2025.1.2'
 
 
 
 
 
 
17
  ]
18
 
19
  [tool.uv.workspace]
 
5
  readme = "README.md"
6
  requires-python = ">=3.13, <3.14"
7
  dependencies = [
8
+ 'onnx == 1.20.1',
9
  'onnxruntime == 1.23.2',
10
+ 'onnxscript == 0.6.0',
11
+ 'onnx-safetensors == 1.5.0',
12
+ 'torch == 2.10.0',
13
+ 'torchvision == 0.25.0',
14
  'transformers == 4.57.3',
15
+ 'pycuda == 2026.1',
16
+ "ipykernel>=7.2.0",
17
+ "pip>=26.0.1",
18
+ "uv>=0.10.2",
19
+ "jupyter>=1.1.1",
20
+ "ipywidgets>=8.1.8",
21
+ "tqdm>=4.67.3",
22
+ "ipython>=9.10.0",
23
  ]
24
 
25
  [tool.uv.workspace]
safetensors_to_onnx.ipynb ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "id": "initial_id",
6
+ "metadata": {
7
+ "collapsed": true,
8
+ "ExecuteTime": {
9
+ "end_time": "2026-02-12T12:37:32.166521648Z",
10
+ "start_time": "2026-02-12T12:37:32.138056109Z"
11
+ }
12
+ },
13
+ "source": [
14
+ "import torch\n",
15
+ "from torch.export import Dim\n",
16
+ "from transformers import T5EncoderModel, AutoTokenizer\n",
17
+ "from pathlib import Path\n",
18
+ "import onnxruntime as ort\n",
19
+ "import numpy as np\n",
20
+ "from inspect import signature"
21
+ ],
22
+ "outputs": [],
23
+ "execution_count": 5
24
+ },
25
+ {
26
+ "metadata": {
27
+ "ExecuteTime": {
28
+ "end_time": "2026-02-12T12:37:00.482648074Z",
29
+ "start_time": "2026-02-12T12:37:00.118707317Z"
30
+ }
31
+ },
32
+ "cell_type": "code",
33
+ "source": [
34
+ "# MODEL_SOURCE_ID = \"ai-forever/FRIDA\"\n",
35
+ "MODEL_SOURCE_ID = \"../FRIDA\"\n",
36
+ "MODEL_TARGET_PATH = Path(\"onnx/frida-onnx\")\n",
37
+ "ONNX_FILE_NAME = \"FRIDA.onnx\"\n",
38
+ "\n",
39
+ "print(\"=\"*50)\n",
40
+ "print(f\"Подготовка директории: {MODEL_TARGET_PATH}\")\n",
41
+ "MODEL_TARGET_PATH.mkdir(parents=True, exist_ok=True)"
42
+ ],
43
+ "id": "ef5e190f02e042b6",
44
+ "outputs": [
45
+ {
46
+ "name": "stdout",
47
+ "output_type": "stream",
48
+ "text": [
49
+ "==================================================\n",
50
+ "Подготовка директории: onnx/frida-onnx\n"
51
+ ]
52
+ }
53
+ ],
54
+ "execution_count": 2
55
+ },
56
+ {
57
+ "metadata": {
58
+ "ExecuteTime": {
59
+ "end_time": "2026-02-12T12:37:17.778488452Z",
60
+ "start_time": "2026-02-12T12:37:16.890360137Z"
61
+ }
62
+ },
63
+ "cell_type": "code",
64
+ "source": [
65
+ "# 1. Загружаем модель и токенизатор\n",
66
+ "print(f\"Загрузка модели и токенизатора из '{MODEL_SOURCE_ID}'...\")\n",
67
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_SOURCE_ID, repo_type=\"model\")\n",
68
+ "model = T5EncoderModel.from_pretrained(MODEL_SOURCE_ID)\n",
69
+ "model.eval()\n",
70
+ "\n",
71
+ "# 2. Создаем тестовые входы\n",
72
+ "print(\"Создание тестовых входных данных...\")\n",
73
+ "test_texts = [\n",
74
+ " \"paraphrase: В Ярославской области разрешили работу бань, но без посетителей\",\n",
75
+ " \"search_query: Сколько программистов нужно, чтобы вкрутить лампочку?\",\n",
76
+ " \"categorize_entailment: Женщину доставили в больницу, за ее жизнь сейчас борются врачи.\"\n",
77
+ "]\n",
78
+ "\n",
79
+ "dummy_inputs = tokenizer(\n",
80
+ " test_texts,\n",
81
+ " max_length=512,\n",
82
+ " padding=\"max_length\",\n",
83
+ " truncation=True,\n",
84
+ " return_tensors=\"pt\"\n",
85
+ ")"
86
+ ],
87
+ "id": "d2913ab82e279832",
88
+ "outputs": [
89
+ {
90
+ "name": "stdout",
91
+ "output_type": "stream",
92
+ "text": [
93
+ "Загрузка модели и токенизатора из '../FRIDA'...\n",
94
+ "Создание тестовых входных данных...\n"
95
+ ]
96
+ }
97
+ ],
98
+ "execution_count": 3
99
+ },
100
+ {
101
+ "metadata": {
102
+ "ExecuteTime": {
103
+ "end_time": "2026-02-12T12:37:34.830442932Z",
104
+ "start_time": "2026-02-12T12:37:34.719042026Z"
105
+ }
106
+ },
107
+ "cell_type": "code",
108
+ "source": "print(signature(model.forward))",
109
+ "id": "e55cf99269a639d2",
110
+ "outputs": [
111
+ {
112
+ "name": "stdout",
113
+ "output_type": "stream",
114
+ "text": [
115
+ "(input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None) -> Union[tuple[torch.FloatTensor], transformers.modeling_outputs.BaseModelOutput]\n"
116
+ ]
117
+ }
118
+ ],
119
+ "execution_count": 6
120
+ },
121
+ {
122
+ "metadata": {},
123
+ "cell_type": "code",
124
+ "source": [
125
+ "# 3. Экспорт с двумя входами\n",
126
+ "onnx_model_path = MODEL_TARGET_PATH / ONNX_FILE_NAME\n",
127
+ "print(f\"Экспорт модели в ONNX формат: {onnx_model_path}\")\n",
128
+ "\n",
129
+ "# For dynamic_shapes\n",
130
+ "batch_size = Dim(\"batch_size\", min=1, max=64) # Optional: add min/max constraints\n",
131
+ "sequence_length = Dim(\"sequence_length\", min=2, max=512)\n",
132
+ "\n",
133
+ "# dynamic_shapes = {\n",
134
+ "# \"input_ids\": {0: batch_size, 1: sequence_length},\n",
135
+ "# \"attention_mask\": {0: batch_size, 1: sequence_length},\n",
136
+ "# \"last_hidden_state\": {0: batch_size, 1: sequence_length}\n",
137
+ "# }\n",
138
+ "\n",
139
+ "# In case of issues use dynamo_export instead of dynamo=True\n",
140
+ "torch.onnx.export(\n",
141
+ " model,\n",
142
+ " (dummy_inputs[\"input_ids\"], dummy_inputs[\"attention_mask\"]),\n",
143
+ " onnx_model_path.as_posix(),\n",
144
+ " input_names=[\"input_ids\", \"attention_mask\"],\n",
145
+ " output_names=[\"last_hidden_state\"],\n",
146
+ " opset_version=20, # Maybe update\n",
147
+ " dynamic_shapes = {\n",
148
+ " \"input_ids\": {0: batch_size, 1: sequence_length},\n",
149
+ " \"attention_mask\": {0: batch_size, 1: sequence_length}\n",
150
+ " },\n",
151
+ " verbose=False,\n",
152
+ " dynamo=True\n",
153
+ ")\n",
154
+ "\n",
155
+ "# 4. Сохраняем токенизатор\n",
156
+ "print(f\"Сохранение токенизатора в '{MODEL_TARGET_PATH}'...\")\n",
157
+ "tokenizer.save_pretrained(MODEL_TARGET_PATH)\n",
158
+ "\n",
159
+ "print(\"Конвертация завершена успешно!\")"
160
+ ],
161
+ "id": "48bfef4b286ae47b",
162
+ "outputs": [],
163
+ "execution_count": null
164
+ },
165
+ {
166
+ "metadata": {},
167
+ "cell_type": "code",
168
+ "source": [
169
+ "# 5. Тестирование и сравнение результатов\n",
170
+ "print(\"\\n\" + \"=\"*50)\n",
171
+ "print(\"ТЕСТИРОВАНИЕ РЕЗУЛЬТАТОВ\")\n",
172
+ "\n",
173
+ "def cls_pooling(hidden_state, attention_mask):\n",
174
+ " \"\"\"CLS pooling для получения эмбеддингов\"\"\"\n",
175
+ " return hidden_state[:, 0]\n",
176
+ "\n",
177
+ "def normalize_embeddings(embeddings):\n",
178
+ " \"\"\"Нормализация эмбеддингов\"\"\"\n",
179
+ " return embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)\n",
180
+ "\n",
181
+ "# Тест с оригинальной моделью\n",
182
+ "print(\"Тестирование оригинальной модели...\")\n",
183
+ "with torch.no_grad():\n",
184
+ " original_inputs = tokenizer(\n",
185
+ " test_texts,\n",
186
+ " max_length=512,\n",
187
+ " padding=True,\n",
188
+ " truncation=True,\n",
189
+ " return_tensors=\"pt\"\n",
190
+ " )\n",
191
+ " original_outputs = model(**original_inputs)\n",
192
+ " original_embeddings = cls_pooling(\n",
193
+ " original_outputs.last_hidden_state,\n",
194
+ " original_inputs[\"attention_mask\"]\n",
195
+ " )\n",
196
+ " original_embeddings = torch.nn.functional.normalize(original_embeddings, p=2, dim=1)\n",
197
+ "\n",
198
+ "# Тест с ONNX моделью\n",
199
+ "print(\"Тестирование ONNX модели...\")\n",
200
+ "onnx_session = ort.InferenceSession(onnx_model_path.as_posix())\n",
201
+ "\n",
202
+ "onnx_inputs = tokenizer(\n",
203
+ " test_texts,\n",
204
+ " max_length=512,\n",
205
+ " padding=True,\n",
206
+ " truncation=True,\n",
207
+ " return_tensors=\"np\"\n",
208
+ ")\n",
209
+ "\n",
210
+ "\n",
211
+ "onnx_inputs_int64 = {\n",
212
+ " \"input_ids\": onnx_inputs[\"input_ids\"].astype(np.int64),\n",
213
+ " \"attention_mask\": onnx_inputs[\"attention_mask\"].astype(np.int64)\n",
214
+ "}\n",
215
+ "\n",
216
+ "onnx_outputs = onnx_session.run(None, onnx_inputs_int64)[0]\n",
217
+ "\n",
218
+ "onnx_embeddings = onnx_outputs[:, 0]\n",
219
+ "onnx_embeddings = normalize_embeddings(onnx_embeddings)\n",
220
+ "\n",
221
+ "cosine_similarity = np.sum(original_embeddings.numpy() * onnx_embeddings, axis=1)\n",
222
+ "print(f\"\\nCosine similarity между оригинальной и ONNX моделью:\")\n",
223
+ "for i, sim in enumerate(cosine_similarity):\n",
224
+ " print(f\" Текст {i+1}: {sim:.6f}\")\n",
225
+ "print(f\"Средняя схожесть: {np.mean(cosine_similarity):.6f}\")\n",
226
+ "\n",
227
+ "print(\"\\n\" + \"=\"*50)\n",
228
+ "print(\"ГОТОВО! Модель успешно конвертирована и протестирована.\")\n",
229
+ "print(f\"Путь к модели: {MODEL_TARGET_PATH.resolve()}\")"
230
+ ],
231
+ "id": "e488535f18210818",
232
+ "outputs": [],
233
+ "execution_count": null
234
+ }
235
+ ],
236
+ "metadata": {
237
+ "kernelspec": {
238
+ "display_name": "Python 3",
239
+ "language": "python",
240
+ "name": "python3"
241
+ },
242
+ "language_info": {
243
+ "codemirror_mode": {
244
+ "name": "ipython",
245
+ "version": 2
246
+ },
247
+ "file_extension": ".py",
248
+ "mimetype": "text/x-python",
249
+ "name": "python",
250
+ "nbconvert_exporter": "python",
251
+ "pygments_lexer": "ipython2",
252
+ "version": "2.7.6"
253
+ }
254
+ },
255
+ "nbformat": 4,
256
+ "nbformat_minor": 5
257
+ }