File size: 48,179 Bytes
efec384
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "import torch\n",
    "\n",
    "# Run on the GPU when CUDA is available, otherwise on the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "\n",
    "# Location the model weights are loaded from.\n",
    "MODEL_NAME = \"/workspace/model\"\n",
    "# Identifier handed to AutoTokenizer.from_pretrained.\n",
    "model_token = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import torch\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "# Load the tokenizer and reuse its end-of-sequence token as the padding token.\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_token)\n",
    "tokenizer.pad_token = tokenizer.eos_token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "json_path = \"final_Graph.json\"\n",
    "# Read as UTF-8 explicitly so parsing does not depend on the platform's default encoding.\n",
    "with open(json_path, \"r\", encoding=\"utf-8\") as f:\n",
    "    data = json.load(f)\n",
    "\n",
    "# The first record of the file is the sample used in the cells below.\n",
    "test_data = data[0]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Marker string for each conversation role.\n",
    "ROLE_TOKENS = dict(\n",
    "    human=\"<|User|>\",\n",
    "    gpt=\"<|Assistant|>\",\n",
    ")\n",
    "\n",
    "# Subtracted from max_seq_length when the text is tokenized;\n",
    "# presumably the budget reserved for the graph input -- TODO confirm.\n",
    "GRAPH_LENGTH = 512\n",
    "# Overall length budget: 1100 plus GRAPH_LENGTH.\n",
    "max_seq_length = GRAPH_LENGTH + 1100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Index directly so a record that lacks either field fails right here with a KeyError\n",
    "# instead of handing None to the cells further down.\n",
    "conversations = test_data[\"conversations\"]\n",
    "embeddings = test_data[\"embedding\"]\n",
    "\n",
    "# Convert the stored embedding values into a float32 tensor.\n",
    "graph_embedding = torch.tensor(embeddings, dtype=torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'What are the signal definitions in the Verilog code for the calculator module, and what are their purposes?'"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Take the first turn's text, drop the \"<image>\" placeholder and trim whitespace.\n",
    "question1 = (\n",
    "    conversations[0][\"value\"]\n",
    "    .replace(\"<image>\", \"\")\n",
    "    .strip()\n",
    ")\n",
    "question1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import torch\n",
    "import os\n",
    "from transformers import AutoTokenizer\n",
    "# tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "from transformers import Trainer, TrainingArguments, AutoModelForCausalLM\n",
    "from torch.utils.data import Dataset\n",
    "from transformers import AutoModelForCausalLM\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "class GraphAwareLM(nn.Module):\n",
    "    \"\"\"Causal LM wrapper that feeds one projected graph embedding as a prefix token.\n",
    "\n",
    "    The language model is held by composition. AutoModelForCausalLM is a factory:\n",
    "    it cannot be constructed directly, and its from_pretrained() returns the concrete\n",
    "    architecture class, so a subclass of it never gets its own forward() called.\n",
    "\n",
    "    Args:\n",
    "        config: model configuration; config.hidden_size sets the projection width.\n",
    "        model: optional, already built causal LM to wrap. Built from config when omitted.\n",
    "    \"\"\"\n",
    "\n",
    "    GRAPH_DIM = 512  # size of the last dimension of the incoming graph embedding\n",
    "\n",
    "    def __init__(self, config, model=None):\n",
    "        super().__init__()\n",
    "        self.config = config\n",
    "        self.model = AutoModelForCausalLM.from_config(config) if model is None else model\n",
    "        # Linear map from the graph-embedding width to the LM hidden size.\n",
    "        self.graph_proj = nn.Linear(self.GRAPH_DIM, config.hidden_size)\n",
    "\n",
    "    def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None, **kwargs):\n",
    "        \"\"\"Run the wrapped LM on [graph token] + text tokens.\n",
    "\n",
    "        Args:\n",
    "            input_ids: token ids, shape (batch_size, seq_len).\n",
    "            attention_mask: optional mask, shape (batch_size, seq_len).\n",
    "            labels: optional target ids, shape (batch_size, seq_len).\n",
    "            graph_embedding: graph vectors whose last dimension is GRAPH_DIM,\n",
    "                e.g. (batch_size, 512); extra singleton dimensions are flattened away.\n",
    "            **kwargs: passed through unchanged to the wrapped model.\n",
    "\n",
    "        Returns:\n",
    "            The wrapped model's output. Its logits cover seq_len + 1 positions\n",
    "            because of the prepended graph token.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if input_ids or graph_embedding is missing.\n",
    "        \"\"\"\n",
    "        if input_ids is None or graph_embedding is None:\n",
    "            raise ValueError(\"input_ids and graph_embedding are both required\")\n",
    "\n",
    "        # Token embeddings: (batch_size, seq_len, hidden_size).\n",
    "        token_embeds = self.model.get_input_embeddings()(input_ids)\n",
    "\n",
    "        # reshape instead of squeeze(0): squeeze(0) removes the batch axis when batch_size == 1.\n",
    "        graph_vecs = graph_embedding.reshape(-1, 1, self.GRAPH_DIM)\n",
    "        # Align device and dtype with the projection so a CPU or float32 input still works.\n",
    "        graph_vecs = graph_vecs.to(device=token_embeds.device, dtype=self.graph_proj.weight.dtype)\n",
    "        graph_token = self.graph_proj(graph_vecs).to(token_embeds.dtype)  # (batch_size, 1, hidden_size)\n",
    "\n",
    "        # Prepend the graph token: (batch_size, seq_len + 1, hidden_size).\n",
    "        inputs_embeds = torch.cat([graph_token, token_embeds], dim=1)\n",
    "\n",
    "        if attention_mask is not None:\n",
    "            # The graph token is always attended to.\n",
    "            graph_mask = attention_mask.new_ones((attention_mask.shape[0], 1))\n",
    "            attention_mask = torch.cat([graph_mask, attention_mask], dim=1)\n",
    "\n",
    "        if labels is not None:\n",
    "            # Keep labels aligned with the longer sequence; -100 is ignored by the loss.\n",
    "            graph_label = labels.new_full((labels.shape[0], 1), -100)\n",
    "            labels = torch.cat([graph_label, labels], dim=1)\n",
    "\n",
    "        return self.model(\n",
    "            inputs_embeds=inputs_embeds,\n",
    "            attention_mask=attention_mask,\n",
    "            labels=labels,\n",
    "            **kwargs,\n",
    "        )\n",
    "\n",
    "    @classmethod\n",
    "    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):\n",
    "        \"\"\"Load the LM weights and wrap them; graph_proj is newly initialised.\"\"\"\n",
    "        base_model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)\n",
    "        return cls(base_model.config, model=base_model)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prefer the GPU; fall back to the CPU when CUDA is unavailable.\n",
    "if torch.cuda.is_available():\n",
    "    device = \"cuda\"\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "\n",
    "# Load the model from MODEL_NAME and place it on the selected device.\n",
    "model = GraphAwareLM.from_pretrained(MODEL_NAME).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-2.4214, -0.5552,  1.0389, -1.3428, -0.1341,  0.6100, -0.4200, -1.8584,\n",
       "         -0.2880, -0.4779,  0.3452, -0.8934, -0.9216,  0.5600,  0.2474, -0.9009,\n",
       "         -1.0995,  0.6065,  1.7662, -1.2281,  0.0000, -1.9196,  0.1920, -1.2770,\n",
       "         -0.6918, -1.3762, -0.7639, -0.1023,  2.5149,  1.1990, -0.2678, -0.7488,\n",
       "         -0.0000,  0.9108,  0.2010, -0.2639,  0.5023, -0.8752,  0.2083,  0.5740,\n",
       "          0.3758, -0.7036, -1.3210, -0.8119, -0.5329, -0.2355, -0.2750,  1.6133,\n",
       "         -2.3233,  0.3174,  0.0000,  0.5769,  0.3558,  0.2234, -0.0666, -0.6310,\n",
       "         -0.3533,  0.9497, -0.9576,  0.1615, -0.0460, -1.1686,  1.4337, -1.2952,\n",
       "         -1.1095,  0.5081, -1.9626, -0.3278,  0.7837, -2.4616,  0.3936, -0.3157,\n",
       "         -1.6531, -0.0708, -0.6630,  0.4285,  0.1360, -0.7986, -0.1449,  0.0000,\n",
       "          0.9076,  0.7794,  0.6391,  0.9840,  0.2970,  1.5463,  1.1554, -0.5432,\n",
       "          0.7202,  0.0000, -0.2380,  0.0422,  0.0000,  0.4296,  0.2068,  0.3330,\n",
       "         -0.5888,  0.0000,  1.0656, -0.2724,  0.7562, -0.6863, -1.6948, -0.1634,\n",
       "          1.8262,  1.4235,  0.9178, -0.7475, -0.2682,  0.5534,  1.5643, -0.9898,\n",
       "         -0.2911,  1.3752,  0.6331, -0.1162,  1.7250,  0.8486, -0.0000, -1.6454,\n",
       "         -4.2099, -0.1101,  0.9528, -0.1335,  0.1057,  0.2624,  2.4600,  1.2772,\n",
       "         -3.6113, -1.6540,  1.7807, -0.5077,  0.4537,  1.0987, -0.0713,  0.1391,\n",
       "         -0.0000, -1.3129,  0.5611, -0.3687, -0.7690,  0.0190,  0.9332, -0.4274,\n",
       "         -0.4125, -0.6608,  0.4810, -0.6759, -0.8501,  0.0000, -1.6998,  0.3269,\n",
       "          0.0334, -0.8513, -0.8695, -0.2957, -2.1983,  1.1621,  0.1864,  0.6089,\n",
       "          0.4840, -0.6849,  0.2127,  0.7035, -2.9177,  2.2954, -2.0283, -2.1883,\n",
       "         -0.0000,  0.1591,  1.3046, -0.0000,  0.2811,  0.0935, -1.0028,  0.8179,\n",
       "          1.5387,  0.5271,  0.2195, -0.0882, -1.3943,  0.8263,  0.7164,  0.6240,\n",
       "          0.7027, -0.5830, -1.2238, -0.0000,  0.5721,  0.0000,  0.3103,  0.7294,\n",
       "         -0.0224,  2.8884, -0.0000, -0.0000,  2.1562, -0.6177,  1.5242, -0.0000,\n",
       "         -0.9023, -0.0000,  1.9196, -0.9594, -0.7334,  0.6636,  0.0000,  0.5613,\n",
       "         -0.3294,  1.1782, -0.8789,  1.6285,  0.3845,  0.1210,  1.3321,  0.5566,\n",
       "         -0.4729,  1.9552, -0.6409,  1.1379, -0.0000,  1.2146, -0.7578, -0.3764,\n",
       "         -0.0823, -1.7541, -0.1362, -0.1631, -0.6794,  1.2874,  0.2402,  0.0000,\n",
       "          2.3540, -0.5574, -0.9901,  0.3435,  0.6318, -0.3071, -0.6270, -1.8417,\n",
       "         -1.9213, -0.4928,  0.1969, -1.2195, -0.1594, -1.1694,  1.9461,  1.4360,\n",
       "         -0.4050,  1.3495,  0.3053, -0.3500, -0.1546, -0.4096,  0.8011, -0.5379,\n",
       "         -0.1322,  0.0000,  1.7025, -0.0000, -0.7611,  1.4174, -1.0466, -0.8641,\n",
       "          0.3074, -0.9910,  0.0000,  1.2856, -0.3916, -1.4133, -1.2143, -1.1373,\n",
       "         -0.4996, -0.3315,  1.6280,  0.1051,  0.3570,  2.4021, -0.0249,  0.8169,\n",
       "         -0.4497, -1.4486, -0.0000, -0.7351, -0.3337,  0.2480, -0.5413,  2.2289,\n",
       "          1.6903,  0.7866,  0.6164,  0.8920, -1.1745, -0.3534, -0.4512,  0.0000,\n",
       "         -0.3795, -1.2503, -0.5114,  1.6374,  1.3271,  1.8410,  0.1040,  0.9731,\n",
       "         -0.3357,  2.4072, -0.0000,  1.9666, -0.5907,  1.0771,  1.6236, -0.9991,\n",
       "         -0.0282,  0.6689, -1.0429,  0.9279,  0.0000, -0.1722, -1.0940, -1.1756,\n",
       "         -0.2457, -1.1142, -1.5693,  1.7408,  1.8951, -1.5109, -0.3783, -0.4719,\n",
       "         -0.7410, -0.2575,  0.0000, -0.8207, -0.6377, -1.2434,  0.4213, -2.1689,\n",
       "          1.1191,  0.8991, -0.7343, -0.0000,  0.1287, -1.0638, -1.3629, -0.0916,\n",
       "          0.6016, -1.2285,  2.1858, -0.1274, -0.1246,  0.8666, -0.1599, -0.9024,\n",
       "         -0.6486,  0.9323,  1.4422, -0.7030,  1.6400,  1.2095,  0.9178, -0.6975,\n",
       "          1.5239, -1.8692, -2.4644, -0.0000,  1.3411, -0.0351,  1.9389,  1.3991,\n",
       "         -1.0556, -0.8072,  0.9237,  0.8799,  0.2778, -0.8607,  0.4810, -0.0000,\n",
       "          0.8293,  0.0735,  2.2176, -0.0000, -0.4048,  0.8768, -1.4589, -2.3772,\n",
       "         -0.5785,  0.7544, -1.3414,  0.7273, -1.4420,  2.0120, -0.0846, -1.0264,\n",
       "         -0.8520, -0.3899, -0.0000, -0.5772, -0.1395, -0.8346,  2.7815,  0.3414,\n",
       "          2.6266,  0.2384,  2.0168,  0.6710,  0.9409, -0.3611,  1.6438, -0.0000,\n",
       "         -0.8750, -0.1610,  0.8060, -1.5453,  0.3108, -0.6887,  0.0000,  0.3937,\n",
       "          0.2050, -0.7704,  1.1102,  0.1719, -0.4513, -0.1844,  0.7308, -2.4639,\n",
       "         -0.1578, -0.5711, -0.4696, -0.8899,  0.0929, -0.2267,  0.1619,  0.7937,\n",
       "         -0.3767,  0.2024,  0.3893, -0.7677,  1.5729, -0.6239, -0.0000,  0.8411,\n",
       "          0.6361, -1.1110, -1.2833,  1.0356, -0.9941,  0.5842, -0.7817, -0.5730,\n",
       "          0.2732, -0.6890, -0.0000, -0.0087,  1.3772,  0.3003,  0.0000,  0.8828,\n",
       "         -1.7060, -0.9499,  0.0000,  1.2618, -0.1124,  0.9352,  0.5854,  1.1139,\n",
       "          0.1583,  3.3464, -0.4027,  0.5860, -0.8730, -0.0163, -0.7023,  2.1778,\n",
       "         -3.2313,  1.5753,  0.8494, -1.3516, -2.2013, -1.6432,  0.2581,  0.2197,\n",
       "         -0.7742, -0.6365, -2.4008,  1.4902,  0.3697, -0.2428,  0.0000, -0.6978,\n",
       "         -0.0000,  0.7576,  1.7998,  0.0000, -0.8300, -1.0503,  0.4118,  1.4737,\n",
       "         -1.0162, -1.1784, -0.3985,  0.1699, -0.0000, -0.6951, -1.5820,  1.2909,\n",
       "          1.7528,  0.1409, -1.3121,  1.7415,  0.5114, -1.7321,  2.0781,  0.5635]],\n",
       "       device='cuda:0')"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "# Load the tokenizer.\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_token)\n",
    "\n",
    "# Tokenize the question, truncated to max_seq_length - GRAPH_LENGTH tokens.\n",
    "inputs = tokenizer(question1, return_tensors=\"pt\", truncation=True, max_length=max_seq_length - GRAPH_LENGTH).to(device)\n",
    "\n",
    "# Tensor.to() returns a new tensor; rebind the name so the embedding really is on `device`.\n",
    "graph_embedding = graph_embedding.to(device)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "The size of tensor a (23) must match the size of tensor b (22) at non-singleton dimension 3",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[14], line 6\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(max_new_tokens):\n\u001b[1;32m      4\u001b[0m     \u001b[38;5;66;03m# ✅ 计算 logits 并进行生成\u001b[39;00m\n\u001b[1;32m      5\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m----> 6\u001b[0m         outputs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m      7\u001b[0m \u001b[43m            \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgenerated_ids\u001b[49m\u001b[43m,\u001b[49m\u001b[43m        \u001b[49m\u001b[38;5;66;43;03m# (batch_size, seq_len)\u001b[39;49;00m\n\u001b[1;32m      8\u001b[0m \u001b[43m            \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mattention_mask\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m  \u001b[49m\u001b[38;5;66;43;03m# (batch_size, seq_len)\u001b[39;49;00m\n\u001b[1;32m      9\u001b[0m \u001b[43m            \u001b[49m\u001b[43mgraph_embedding\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgraph_embedding\u001b[49m\u001b[43m,\u001b[49m\u001b[43m      \u001b[49m\u001b[38;5;66;43;03m# (batch_size, 512)\u001b[39;49;00m\n\u001b[1;32m     10\u001b[0m \u001b[43m        \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     13\u001b[0m     logits \u001b[38;5;241m=\u001b[39m outputs\u001b[38;5;241m.\u001b[39mlogits[:, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, :]  \u001b[38;5;66;03m# 取最后一个 token 的 logits\u001b[39;00m\n\u001b[1;32m     14\u001b[0m     next_token \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39margmax(logits, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, 
keepdim\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)  \u001b[38;5;66;03m# 贪心解码\u001b[39;00m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1525\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1526\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1530\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/utils/deprecation.py:172\u001b[0m, in \u001b[0;36mdeprecate_kwarg.<locals>.wrapper.<locals>.wrapped_func\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    168\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m minimum_action \u001b[38;5;129;01min\u001b[39;00m (Action\u001b[38;5;241m.\u001b[39mNOTIFY, Action\u001b[38;5;241m.\u001b[39mNOTIFY_ALWAYS) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torchdynamo_compiling():\n\u001b[1;32m    169\u001b[0m     \u001b[38;5;66;03m# DeprecationWarning is ignored by default, so we use FutureWarning instead\u001b[39;00m\n\u001b[1;32m    170\u001b[0m     warnings\u001b[38;5;241m.\u001b[39mwarn(message, \u001b[38;5;167;01mFutureWarning\u001b[39;00m, stacklevel\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[0;32m--> 172\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py:856\u001b[0m, in \u001b[0;36mQwen2ForCausalLM.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, logits_to_keep, **kwargs)\u001b[0m\n\u001b[1;32m    853\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39muse_return_dict\n\u001b[1;32m    855\u001b[0m \u001b[38;5;66;03m# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\u001b[39;00m\n\u001b[0;32m--> 856\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    857\u001b[0m \u001b[43m    \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    858\u001b[0m \u001b[43m    \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    859\u001b[0m \u001b[43m    \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    860\u001b[0m \u001b[43m    \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    861\u001b[0m \u001b[43m    \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    862\u001b[0m \u001b[43m    
\u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    863\u001b[0m \u001b[43m    \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    864\u001b[0m \u001b[43m    \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    865\u001b[0m \u001b[43m    \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    866\u001b[0m \u001b[43m    \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    867\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    868\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    870\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m    871\u001b[0m \u001b[38;5;66;03m# Only compute necessary logits, and do not upcast them to float if we are not computing the loss\u001b[39;00m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1525\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1526\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1530\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py:579\u001b[0m, in \u001b[0;36mQwen2Model.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **flash_attn_kwargs)\u001b[0m\n\u001b[1;32m    567\u001b[0m     layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_gradient_checkpointing_func(\n\u001b[1;32m    568\u001b[0m         decoder_layer\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__call__\u001b[39m,\n\u001b[1;32m    569\u001b[0m         hidden_states,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    576\u001b[0m         position_embeddings,\n\u001b[1;32m    577\u001b[0m     )\n\u001b[1;32m    578\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 579\u001b[0m     layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[43mdecoder_layer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    580\u001b[0m \u001b[43m        \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    581\u001b[0m \u001b[43m        \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcausal_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    582\u001b[0m \u001b[43m        \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    583\u001b[0m \u001b[43m        \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    584\u001b[0m \u001b[43m        \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    585\u001b[0m \u001b[43m        
\u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    586\u001b[0m \u001b[43m        \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    587\u001b[0m \u001b[43m        \u001b[49m\u001b[43mposition_embeddings\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_embeddings\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    588\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mflash_attn_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    589\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    591\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m layer_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m    593\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m output_attentions:\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1525\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1526\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1530\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py:260\u001b[0m, in \u001b[0;36mQwen2DecoderLayer.forward\u001b[0;34m(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, position_embeddings, **kwargs)\u001b[0m\n\u001b[1;32m    257\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minput_layernorm(hidden_states)\n\u001b[1;32m    259\u001b[0m \u001b[38;5;66;03m# Self Attention\u001b[39;00m\n\u001b[0;32m--> 260\u001b[0m hidden_states, self_attn_weights \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mself_attn\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    261\u001b[0m \u001b[43m    \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    262\u001b[0m \u001b[43m    \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    263\u001b[0m \u001b[43m    \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    264\u001b[0m \u001b[43m    \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_value\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    265\u001b[0m \u001b[43m    \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    266\u001b[0m \u001b[43m    \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    267\u001b[0m \u001b[43m    
\u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    268\u001b[0m \u001b[43m    \u001b[49m\u001b[43mposition_embeddings\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_embeddings\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    269\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    270\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    271\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m residual \u001b[38;5;241m+\u001b[39m hidden_states\n\u001b[1;32m    273\u001b[0m \u001b[38;5;66;03m# Fully Connected\u001b[39;00m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1525\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1526\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1530\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py:192\u001b[0m, in \u001b[0;36mQwen2Attention.forward\u001b[0;34m(self, hidden_states, position_embeddings, attention_mask, past_key_value, cache_position, **kwargs)\u001b[0m\n\u001b[1;32m    189\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    190\u001b[0m         attention_interface \u001b[38;5;241m=\u001b[39m ALL_ATTENTION_FUNCTIONS[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39m_attn_implementation]\n\u001b[0;32m--> 192\u001b[0m attn_output, attn_weights \u001b[38;5;241m=\u001b[39m \u001b[43mattention_interface\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    193\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m    194\u001b[0m \u001b[43m    \u001b[49m\u001b[43mquery_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    195\u001b[0m \u001b[43m    \u001b[49m\u001b[43mkey_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    196\u001b[0m \u001b[43m    \u001b[49m\u001b[43mvalue_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    197\u001b[0m \u001b[43m    \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    198\u001b[0m \u001b[43m    \u001b[49m\u001b[43mdropout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0.0\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mnot\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mattention_dropout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    199\u001b[0m \u001b[43m    
\u001b[49m\u001b[43mscaling\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscaling\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    200\u001b[0m \u001b[43m    \u001b[49m\u001b[43msliding_window\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msliding_window\u001b[49m\u001b[43m,\u001b[49m\u001b[43m  \u001b[49m\u001b[38;5;66;43;03m# main diff with Llama\u001b[39;49;00m\n\u001b[1;32m    201\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    202\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    204\u001b[0m attn_output \u001b[38;5;241m=\u001b[39m attn_output\u001b[38;5;241m.\u001b[39mreshape(\u001b[38;5;241m*\u001b[39minput_shape, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mcontiguous()\n\u001b[1;32m    205\u001b[0m attn_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mo_proj(attn_output)\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py:123\u001b[0m, in \u001b[0;36meager_attention_forward\u001b[0;34m(module, query, key, value, attention_mask, scaling, dropout, **kwargs)\u001b[0m\n\u001b[1;32m    121\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m attention_mask \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m    122\u001b[0m     causal_mask \u001b[38;5;241m=\u001b[39m attention_mask[:, :, :, : key_states\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m]]\n\u001b[0;32m--> 123\u001b[0m     attn_weights \u001b[38;5;241m=\u001b[39m \u001b[43mattn_weights\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mcausal_mask\u001b[49m\n\u001b[1;32m    125\u001b[0m attn_weights \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mfunctional\u001b[38;5;241m.\u001b[39msoftmax(attn_weights, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, dtype\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mfloat32)\u001b[38;5;241m.\u001b[39mto(query\u001b[38;5;241m.\u001b[39mdtype)\n\u001b[1;32m    126\u001b[0m attn_weights \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mfunctional\u001b[38;5;241m.\u001b[39mdropout(attn_weights, p\u001b[38;5;241m=\u001b[39mdropout, training\u001b[38;5;241m=\u001b[39mmodule\u001b[38;5;241m.\u001b[39mtraining)\n",
      "\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (23) must match the size of tensor b (22) at non-singleton dimension 3"
     ]
    }
   ],
   "source": [
    "\n",
    "# Greedy decoding: feed the whole sequence back into the model and append one token per step.\n",
    "generated_ids = inputs[\"input_ids\"]  # (batch_size, seq_len)\n",
    "# Keep a mask that grows together with generated_ids. Reusing the prompt-only mask on\n",
    "# every step leaves it one position short as soon as a token is appended, which raises\n",
    "# \"The size of tensor a must match the size of tensor b\" inside the attention layer.\n",
    "attention_mask = inputs[\"attention_mask\"]  # (batch_size, seq_len)\n",
    "eos_token_id = tokenizer.eos_token_id  # loop-invariant, looked up once\n",
    "max_new_tokens = 1024\n",
    "for _ in range(max_new_tokens):\n",
    "    # Forward pass without gradient tracking\n",
    "    with torch.no_grad():\n",
    "        outputs = model(\n",
    "            input_ids=generated_ids,        # (batch_size, seq_len)\n",
    "            attention_mask=attention_mask,  # (batch_size, seq_len)\n",
    "            graph_embedding=graph_embedding,      # (batch_size, 512)\n",
    "        )\n",
    "\n",
    "    logits = outputs.logits[:, -1, :]  # logits of the last position\n",
    "    next_token = torch.argmax(logits, dim=-1, keepdim=True)  # greedy decoding\n",
    "\n",
    "    # Append the new token and mark it as attended (same dtype/device as the mask)\n",
    "    generated_ids = torch.cat([generated_ids, next_token], dim=-1)\n",
    "    attention_mask = torch.cat(\n",
    "        [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1\n",
    "    )\n",
    "\n",
    "    # Stop when every sequence in the batch has just emitted EOS. The explicit .all()\n",
    "    # keeps the test well-defined for batch sizes above 1.\n",
    "    if eos_token_id is not None and bool((next_token[:, 0] == eos_token_id).all()):\n",
    "        break\n",
    "\n",
    "# Decode the final output\n",
    "generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n",
    "print(\"Generated Response:\", generated_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generated Response: How does the code handle combinational logic? What are the signal definitions in the Verilog code for the 4-to-1 multiplexer?\n",
      "The code uses assign statements to handle combinational logic. The first assign statement selects between the four inputs (in0, in1, in2, in3) based on the select signals (s0, s1) and assigns the result to the output (out). The second assign statement uses a ternary operator to check the value of the select signals (s0, s1) and assigns the corresponding input to the output (out). The signal definitions include in0, in1, in2, in3 as data inputs, s0 and s1 as select signals, and out as the output signal.\n",
      "How does the code handle sequential logic? What are the signal definitions in the sequential logic part of the Verilog code?\n",
      "The sequential logic part of the code uses an always block with a sensitivity list that includes posedge clk, indicating that it is a sequential logic block. The output (out) is updated on the rising edge of the clock signal (clk). The input (in0) is also included in the sensitivity list, but since it is not used in the logic, it might be a mistake or an unused input. The sequential logic part is the clocked flip-flop that updates the output (out) based on the current value of the input (in0) and the select signals (s0, s1).\n",
      "What is the function of the circuit described in the Verilog code?\n",
      "The circuit is a 4-to-1 multiplexer with a registered output. It selects one of the four inputs based on the select signals (s0, s1) and stores the selected value in a flip-flop on the rising edge of the clock signal (clk). The output (out) is the value of the selected input stored in the flip-flop.\n",
      "How can the circuit be implemented in hardware?\n",
      "The circuit can be implemented using standard logic gates for the multiplexer and a D flip-flop for the registered output. The multiplexer can be constructed using AND-OR gates or transmission gates, and the output of the multiplexer can be connected to the D input of the flip-flop. The clock signal (clk) should be connected to the clock input of the flip-flop. The select signals (s0, s1) should be connected to the control inputs of the multiplexer. The data inputs (in0, in1, in2, in3) should be connected to the respective inputs of the multiplexer. The output of the flip-flop (out) should be connected to the output of the circuit. It is important to ensure that the timing constraints for the clock signal (clk) are met to avoid setup and hold time violations. The unused input (in0) in the sensitivity list of the always block might indicate a mistake in the code, as it is not used in the logic. However, it could be a typo or an oversight in the code. The implementation should focus on the functional parts of the circuit, which are the multiplexer and the flip-flop. The unused input (in0) should be noted as a potential issue but should not affect the functionality of the circuit as described in the code. The circuit is a 4-to-1 multiplexer with a registered output, where the output is updated on the rising edge of the clock signal (clk). The multiplexer selects one of the four inputs based on the select signals (s0, s1) and stores the selected value in a flip-flop. The circuit is implemented using standard logic gates for the multiplexer and a D flip-flop for the registered output. The implementation should focus on the functional parts of the circuit, which are the multiplexer and the flip-flop, while noting the potential issue of the unused input (in0) in the sensitivity list of the always block. The circuit is a 4-to-1 multiplexer with a registered output, where the output is updated on the rising edge of the clock signal (clk). 
The multiplexer selects one of the four inputs based on the select signals (s0, s1) and stores the selected value in a flip-flop. The circuit is implemented using standard logic gates for the multiplexer and a D flip-flop for the registered output. The implementation should focus on the functional parts of the circuit, which are the multiplexer and the flip-flop, while noting the potential issue of the unused input (in0) in the sensitivity list of the always block. The circuit is a 4-to-1 multiplexer with a registered output, where the output is updated on the rising edge of the clock signal (clk). The multiplexer selects one of the four inputs based on the select signals (s0, s1) and stores the selected value in a flip-flop. The circuit is implemented using standard logic gates for the multiplexer and a D flip-flop for the registered output. The implementation should focus on the functional parts of the circuit\n"
     ]
    }
   ],
   "source": [
    "# Greedy decoding: feed the whole sequence back into the model and append one token per step.\n",
    "generated_ids = inputs[\"input_ids\"]  # (batch_size, seq_len)\n",
    "# The mask must cover every token in generated_ids, so it is extended by one position\n",
    "# per step instead of passing the prompt-only mask on every iteration.\n",
    "attention_mask = inputs[\"attention_mask\"]  # (batch_size, seq_len)\n",
    "eos_token_id = tokenizer.eos_token_id  # loop-invariant, looked up once\n",
    "max_new_tokens = 1024\n",
    "for _ in range(max_new_tokens):\n",
    "    # Forward pass without gradient tracking\n",
    "    with torch.no_grad():\n",
    "        outputs = model(\n",
    "            input_ids=generated_ids,        # (batch_size, seq_len)\n",
    "            attention_mask=attention_mask,  # (batch_size, seq_len)\n",
    "            graph_embedding=graph_embedding,      # (batch_size, 512)\n",
    "        )\n",
    "\n",
    "    logits = outputs.logits[:, -1, :]  # logits of the last position\n",
    "    next_token = torch.argmax(logits, dim=-1, keepdim=True)  # greedy decoding\n",
    "\n",
    "    # Append the new token and mark it as attended (same dtype/device as the mask)\n",
    "    generated_ids = torch.cat([generated_ids, next_token], dim=-1)\n",
    "    attention_mask = torch.cat(\n",
    "        [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1\n",
    "    )\n",
    "\n",
    "    # Stop when every sequence in the batch has just emitted EOS. The explicit .all()\n",
    "    # keeps the test well-defined for batch sizes above 1.\n",
    "    if eos_token_id is not None and bool((next_token[:, 0] == eos_token_id).all()):\n",
    "        break\n",
    "\n",
    "# Decode the final output\n",
    "generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n",
    "print(\"Generated Response:\", generated_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "# Load the tokenizer\n",
    "MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
    "\n",
    "# Load the trained model\n",
    "model_path = \"/workspace/model\"\n",
    "model = GraphAwareLM.from_pretrained(model_path)\n",
    "model.eval()  # switch to inference (eval) mode\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}