Okoge-keys committed on
Commit
6df9022
·
verified ·
1 Parent(s): 3b1c7d9

Delete rnn - コピー.ipynb

Browse files
Files changed (1) hide show
  1. rnn - コピー.ipynb +0 -355
rnn - コピー.ipynb DELETED
@@ -1,355 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "metadata": {},
6
- "source": [
7
- "テキストデータのTensor化\n",
8
- "1. テキストの読み込み\n",
9
- "2. テキストのトークン化\n",
10
- "3. トークンのインデックス化\n",
11
- "4. 複数テキストのバッチ化\n",
12
- "5. テキストの単語ベクトル化"
13
- ]
14
- },
15
- {
16
- "cell_type": "code",
17
- "execution_count": 1,
18
- "metadata": {},
19
- "outputs": [
20
- {
21
- "name": "stderr",
22
- "output_type": "stream",
23
- "text": [
24
- "c:\\Users\\kenta\\AppData\\Local\\Programs\\Python\\workspace_env\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
25
- " from .autonotebook import tqdm as notebook_tqdm\n"
26
- ]
27
- }
28
- ],
29
- "source": [
30
- "import torch\n",
31
- "import torch.nn as nn\n",
32
- "from torch.utils.data import DataLoader, Dataset, random_split\n",
33
- "from transformers import AutoTokenizer\n",
34
- "from datasets import load_dataset\n",
35
- "import torch.optim as optim\n",
36
- "from tqdm import tqdm"
37
- ]
38
- },
39
- {
40
- "cell_type": "code",
41
- "execution_count": 2,
42
- "metadata": {},
43
- "outputs": [
44
- {
45
- "data": {
46
- "text/plain": [
47
- "Dataset({\n",
48
- " features: ['text', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],\n",
49
- " num_rows: 25000\n",
50
- "})"
51
- ]
52
- },
53
- "execution_count": 2,
54
- "metadata": {},
55
- "output_type": "execute_result"
56
- }
57
- ],
58
- "source": [
59
- "# データセットのロード\n",
60
- "dataset = load_dataset(\"stanfordnlp/imdb\")\n",
61
- "train_texts = dataset['train']['text'][:100]\n",
62
- "train_labels = dataset['train']['label'][:100]\n",
63
- "# test_texts = dataset['test']['text'][100:]\n",
64
- "# test_labels = dataset['test']['label'][100:]\n",
65
- "\n",
66
- "# トークナイザーの準備\n",
67
- "tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')\n",
68
- "\n",
69
- "# テキストのトークン化とインデックス化\n",
70
- "def tokenize_function(text_dataset):\n",
71
- " return tokenizer(text_dataset['text'], \n",
72
- " padding=True,\n",
73
- " truncation=True,\n",
74
- " max_length=256)\n",
75
- "\n",
76
- "train_encodings = dataset['train'].map(tokenize_function, batched=True)\n",
77
- "train_encodings\n",
78
- "# test_encodings = dataset['test'].map(tokenize_function, batched=True)\n"
79
- ]
80
- },
81
- {
82
- "cell_type": "code",
83
- "execution_count": 3,
84
- "metadata": {},
85
- "outputs": [
86
- {
87
- "data": {
88
- "text/plain": [
89
- "['I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered \"controversial\" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the sex and nudity scenes are few and far between, even then it\\'s not shot like some cheaply made porno. While my countrymen mind find it shocking, in reality sex and nudity are a major staple in Swedish cinema. Even Ingmar Bergman, arguably their answer to good old boy John Ford, had sex scenes in his films.<br /><br />I do commend the filmmakers for the fact that any sex shown in the film is shown for artistic purposes rather than just to shock people and make money to be shown in pornographic theaters in America. I AM CURIOUS-YELLOW is a good film for anyone wanting to study the meat and potatoes (no pun intended) of Swedish cinema. But really, this film doesn\\'t have much of a plot.']"
90
- ]
91
- },
92
- "execution_count": 3,
93
- "metadata": {},
94
- "output_type": "execute_result"
95
- }
96
- ],
97
- "source": [
98
- "train_encodings['text'][:1]"
99
- ]
100
- },
101
- {
102
- "cell_type": "code",
103
- "execution_count": 4,
104
- "metadata": {},
105
- "outputs": [
106
- {
107
- "data": {
108
- "text/plain": [
109
- "list"
110
- ]
111
- },
112
- "execution_count": 4,
113
- "metadata": {},
114
- "output_type": "execute_result"
115
- }
116
- ],
117
- "source": [
118
- "type(train_encodings['input_ids'])"
119
- ]
120
- },
121
- {
122
- "cell_type": "code",
123
- "execution_count": 5,
124
- "metadata": {},
125
- "outputs": [],
126
- "source": [
127
- "\n",
128
- "# カスタムデータセットクラス\n",
129
- "class CustomDataset(Dataset):\n",
130
- " def __init__(self, encodings, labels):\n",
131
- " self.encodings = encodings\n",
132
- " self.labels = labels\n",
133
- " \n",
134
- " def __getitem__(self, idx):\n",
135
- " item = {\n",
136
- " 'input_ids': torch.tensor(self.encodings['input_ids'][idx]),\n",
137
- " 'attention_mask': torch.tensor(self.encodings['attention_mask'][idx])\n",
138
- " }\n",
139
- " item['labels'] = torch.tensor(self.labels[idx])\n",
140
- " return item\n",
141
- " \n",
142
- " def __len__(self):\n",
143
- " return len(self.labels)\n",
144
- "\n",
145
- "train_dataset = CustomDataset(train_encodings, train_labels)\n",
146
- "# test_dataset = CustomDataset(test_encodings, test_labels)\n"
147
- ]
148
- },
149
- {
150
- "cell_type": "code",
151
- "execution_count": 6,
152
- "metadata": {},
153
- "outputs": [],
154
- "source": [
155
- "\n",
156
- "# # 訓練データをtrainとvalに分割\n",
157
- "# train_size = int(0.7 * len(train_dataset))\n",
158
- "# val_size = len(train_dataset) - train_size\n",
159
- "# train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])\n",
160
- "\n",
161
- "# データローダー\n",
162
- "train_loader = DataLoader(train_dataset, \n",
163
- " batch_size=32, \n",
164
- " shuffle=True,\n",
165
- " num_workers=8, # 並列データロード\n",
166
- " pin_memory=True, # GPUへの転送を高速化\n",
167
- " # prefetch_factor=2 # 先読み\n",
168
- " )"
169
- ]
170
- },
171
- {
172
- "cell_type": "code",
173
- "execution_count": 7,
174
- "metadata": {},
175
- "outputs": [
176
- {
177
- "ename": "RuntimeError",
178
- "evalue": "DataLoader worker (pid(s) 37820, 28624, 37972, 20972, 21228, 28476, 11500, 38532) exited unexpectedly",
179
- "output_type": "error",
180
- "traceback": [
181
- "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
182
- "\u001b[1;31mEmpty\u001b[0m Traceback (most recent call last)",
183
- "File \u001b[1;32mc:\\Users\\kenta\\AppData\\Local\\Programs\\Python\\workspace_env\\Lib\\site-packages\\torch\\utils\\data\\dataloader.py:1251\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._try_get_data\u001b[1;34m(self, timeout)\u001b[0m\n\u001b[0;32m 1250\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m-> 1251\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data_queue\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1252\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m (\u001b[38;5;28;01mTrue\u001b[39;00m, data)\n",
184
- "File \u001b[1;32m~\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\queue.py:179\u001b[0m, in \u001b[0;36mQueue.get\u001b[1;34m(self, block, timeout)\u001b[0m\n\u001b[0;32m 178\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m remaining \u001b[38;5;241m<\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0.0\u001b[39m:\n\u001b[1;32m--> 179\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m Empty\n\u001b[0;32m 180\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnot_empty\u001b[38;5;241m.\u001b[39mwait(remaining)\n",
185
- "\u001b[1;31mEmpty\u001b[0m: ",
186
- "\nThe above exception was the direct cause of the following exception:\n",
187
- "\u001b[1;31mRuntimeError\u001b[0m Traceback (most recent call last)",
188
- "Cell \u001b[1;32mIn[7], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m \u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mtrain_loader\u001b[49m\u001b[43m:\u001b[49m\n\u001b[0;32m 2\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mprint\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 3\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mbreak\u001b[39;49;00m\n",
189
- "File \u001b[1;32mc:\\Users\\kenta\\AppData\\Local\\Programs\\Python\\workspace_env\\Lib\\site-packages\\torch\\utils\\data\\dataloader.py:708\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 705\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sampler_iter \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m 706\u001b[0m \u001b[38;5;66;03m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[0;32m 707\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reset() \u001b[38;5;66;03m# type: ignore[call-arg]\u001b[39;00m\n\u001b[1;32m--> 708\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_next_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 709\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[0;32m 710\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[0;32m 711\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable\n\u001b[0;32m 712\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m 713\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m>\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called\n\u001b[0;32m 714\u001b[0m ):\n",
190
- "File \u001b[1;32mc:\\Users\\kenta\\AppData\\Local\\Programs\\Python\\workspace_env\\Lib\\site-packages\\torch\\utils\\data\\dataloader.py:1458\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._next_data\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 1455\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_process_data(data)\n\u001b[0;32m 1457\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_shutdown \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_tasks_outstanding \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m-> 1458\u001b[0m idx, data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_get_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1459\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_tasks_outstanding \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[0;32m 1460\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable:\n\u001b[0;32m 1461\u001b[0m \u001b[38;5;66;03m# Check for _IterableDatasetStopIteration\u001b[39;00m\n",
191
- "File \u001b[1;32mc:\\Users\\kenta\\AppData\\Local\\Programs\\Python\\workspace_env\\Lib\\site-packages\\torch\\utils\\data\\dataloader.py:1410\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._get_data\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 1408\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pin_memory:\n\u001b[0;32m 1409\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pin_memory_thread\u001b[38;5;241m.\u001b[39mis_alive():\n\u001b[1;32m-> 1410\u001b[0m success, data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_try_get_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1411\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m success:\n\u001b[0;32m 1412\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m data\n",
192
- "File \u001b[1;32mc:\\Users\\kenta\\AppData\\Local\\Programs\\Python\\workspace_env\\Lib\\site-packages\\torch\\utils\\data\\dataloader.py:1264\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._try_get_data\u001b[1;34m(self, timeout)\u001b[0m\n\u001b[0;32m 1262\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(failed_workers) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m 1263\u001b[0m pids_str \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(\u001b[38;5;28mstr\u001b[39m(w\u001b[38;5;241m.\u001b[39mpid) \u001b[38;5;28;01mfor\u001b[39;00m w \u001b[38;5;129;01min\u001b[39;00m failed_workers)\n\u001b[1;32m-> 1264\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[0;32m 1265\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDataLoader worker (pid(s) \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpids_str\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m) exited unexpectedly\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 1266\u001b[0m ) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01me\u001b[39;00m\n\u001b[0;32m 1267\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(e, queue\u001b[38;5;241m.\u001b[39mEmpty):\n\u001b[0;32m 1268\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m (\u001b[38;5;28;01mFalse\u001b[39;00m, \u001b[38;5;28;01mNone\u001b[39;00m)\n",
193
- "\u001b[1;31mRuntimeError\u001b[0m: DataLoader worker (pid(s) 37820, 28624, 37972, 20972, 21228, 28476, 11500, 38532) exited unexpectedly"
194
- ]
195
- }
196
- ],
197
- "source": [
198
- "for batch in train_loader:\n",
199
- " print(batch)\n",
200
- " break"
201
- ]
202
- },
203
- {
204
- "cell_type": "code",
205
- "execution_count": 6,
206
- "metadata": {},
207
- "outputs": [],
208
- "source": [
209
- "\n",
210
- "# LSTMモデルの定義\n",
211
- "class LstmClassifier(nn.Module):\n",
212
- " def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, num_layers=2, dropout=0.5):\n",
213
- " super(LstmClassifier, self).__init__()\n",
214
- " \n",
215
- " # 埋め込み層を追加\n",
216
- " self.embedding = nn.Embedding(vocab_size, embedding_dim)\n",
217
- " \n",
218
- " # LSTM層\n",
219
- " self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)\n",
220
- " self.dropout = nn.Dropout(dropout)\n",
221
- " self.fc = nn.Linear(hidden_dim, output_dim)\n",
222
- " self.softmax = nn.Softmax(dim=1)\n",
223
- " \n",
224
- " def forward(self, x):\n",
225
- " # 埋め込み層を通す\n",
226
- " embedded = self.embedding(x) # (batch_size, seq_length, embedding_dim)\n",
227
- " \n",
228
- " # LSTM層\n",
229
- " lstm_out, (hn, cn) = self.lstm(embedded)\n",
230
- " final_hidden_state = hn[-1]\n",
231
- " \n",
232
- " # ドロップアウトと全結合層\n",
233
- " x = self.dropout(final_hidden_state)\n",
234
- " x = self.fc(x)\n",
235
- " return self.softmax(x)\n"
236
- ]
237
- },
238
- {
239
- "cell_type": "code",
240
- "execution_count": 7,
241
- "metadata": {},
242
- "outputs": [],
243
- "source": [
244
- "\n",
245
- "# モデルのインスタンスを作成\n",
246
- "# input_dim = tokenizer.model_max_length # bertの埋め込みサイズ\n",
247
- "vocab_size = tokenizer.vocab_size # トークナイザーの語彙サイズ\n",
248
- "embedding_dim = 300 # 埋め込みベクトルの次元数\n",
249
- "hidden_dim = 128\n",
250
- "output_dim = 2\n",
251
- "model = LstmClassifier(vocab_size, embedding_dim, hidden_dim, output_dim)\n",
252
- "\n",
253
- "# 最適化手法と損失関数\n",
254
- "optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
255
- "criterion = nn.CrossEntropyLoss()\n"
256
- ]
257
- },
258
- {
259
- "cell_type": "code",
260
- "execution_count": 8,
261
- "metadata": {},
262
- "outputs": [
263
- {
264
- "name": "stdout",
265
- "output_type": "stream",
266
- "text": [
267
- "cuda\n"
268
- ]
269
- },
270
- {
271
- "data": {
272
- "text/plain": [
273
- "LstmClassifier(\n",
274
- " (embedding): Embedding(30522, 300)\n",
275
- " (lstm): LSTM(300, 128, num_layers=2, batch_first=True)\n",
276
- " (dropout): Dropout(p=0.5, inplace=False)\n",
277
- " (fc): Linear(in_features=128, out_features=2, bias=True)\n",
278
- " (softmax): Softmax(dim=1)\n",
279
- ")"
280
- ]
281
- },
282
- "execution_count": 8,
283
- "metadata": {},
284
- "output_type": "execute_result"
285
- }
286
- ],
287
- "source": [
288
- "# 学習ループ\n",
289
- "num_epochs = 1\n",
290
- "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
291
- "print(device)\n",
292
- "\n",
293
- "model.to(device)\n"
294
- ]
295
- },
296
- {
297
- "cell_type": "code",
298
- "execution_count": null,
299
- "metadata": {},
300
- "outputs": [],
301
- "source": [
302
- "\n",
303
- "for epoch in range(num_epochs):\n",
304
- " model.train()\n",
305
- " running_loss = 0.0\n",
306
- " for batch in train_loader:\n",
307
- " print('バッチのごとに処理')\n",
308
- " optimizer.zero_grad()\n",
309
- " print('GPUへ転送開始')\n",
310
- " # データをGPUに転送\n",
311
- " input_ids = batch['input_ids'].to(device)\n",
312
- " print('一つ目転送完了')\n",
313
- " attention_mask = batch['attention_mask'].to(device)\n",
314
- " labels = batch['labels'].to(device)\n",
315
- " # LSTMモデルの予測\n",
316
- " outputs = model(input_ids) # shape:(batch_size, output_dim)\n",
317
- " \n",
318
- " loss = criterion(outputs, labels)\n",
319
- " loss.backward()\n",
320
- " optimizer.step()\n",
321
- "\n",
322
- " running_loss += loss.item()\n",
323
- " print(f\"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader)}\")\n"
324
- ]
325
- },
326
- {
327
- "cell_type": "code",
328
- "execution_count": null,
329
- "metadata": {},
330
- "outputs": [],
331
- "source": []
332
- }
333
- ],
334
- "metadata": {
335
- "kernelspec": {
336
- "display_name": "workspace_env",
337
- "language": "python",
338
- "name": "python3"
339
- },
340
- "language_info": {
341
- "codemirror_mode": {
342
- "name": "ipython",
343
- "version": 3
344
- },
345
- "file_extension": ".py",
346
- "mimetype": "text/x-python",
347
- "name": "python",
348
- "nbconvert_exporter": "python",
349
- "pygments_lexer": "ipython3",
350
- "version": "3.12.9"
351
- }
352
- },
353
- "nbformat": 4,
354
- "nbformat_minor": 2
355
- }