Yuchan committed on
Commit
6075db7
·
verified ·
1 Parent(s): 128be27

Update Model_torch.py

Browse files
Files changed (1) hide show
  1. Model_torch.py +20 -11
Model_torch.py CHANGED
@@ -160,34 +160,43 @@ class ReLM(nn.Module):
160
  logits = x @ self.token_embedding.weight.T
161
  return logits
162
 
163
- # ===============================
164
- # ํ•™์Šต
165
- # ===============================
166
  model = ReLM(vocab_size, max_len, 128, 2).to(device)
167
  optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
168
  scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
169
  loss_fn = nn.CrossEntropyLoss(ignore_index=pad_id)
170
 
 
 
 
 
171
  epochs = 1
172
  for epoch in range(epochs):
173
  model.train()
174
  total_loss = 0
175
- for step,(x,y) in enumerate(dataloader):
176
- x,y = x.to(device), y.to(device)
177
  optimizer.zero_grad()
178
- logits = model(x)
179
- loss = loss_fn(logits.view(-1,vocab_size), y.view(-1))
180
- loss.backward()
181
- torch.nn.utils.clip_grad_norm_(model.parameters(),1.0)
182
- optimizer.step()
 
 
 
 
 
 
183
  total_loss += loss.item()
184
  if step % 100 == 0:
185
  print(f"Epoch {epoch+1}, Step {step}, Loss: {loss.item():.4f}")
 
186
  scheduler.step()
187
  print(f"Epoch {epoch+1} ์™„๋ฃŒ, ํ‰๊ท  Loss: {total_loss/len(dataloader):.4f}")
188
 
189
  torch.save(model.state_dict(), "relm_model.pth")
190
- print("๋ชจ๋ธ ์ €์žฅ ์™„๋ฃŒ!")
191
 
192
  # ===============================
193
  # Top-p ์ƒ˜ํ”Œ๋ง ์ƒ์„ฑ
 
160
  logits = x @ self.token_embedding.weight.T
161
  return logits
162
 
163
+ # ๋ชจ๋ธ, ์˜ตํ‹ฐ๋งˆ์ด์ €, ์Šค์ผ€์ค„๋Ÿฌ, ์†์‹ค ํ•จ์ˆ˜
 
 
164
  model = ReLM(vocab_size, max_len, 128, 2).to(device)
165
  optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
166
  scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
167
  loss_fn = nn.CrossEntropyLoss(ignore_index=pad_id)
168
 
169
+ # ์ •์  ๊ทธ๋ž˜ํ”„ ์ปดํŒŒ์ผ
170
+ model = torch.compile(model, mode="default")
171
+
172
+ scaler = torch.cuda.amp.GradScaler()
173
  epochs = 1
174
  for epoch in range(epochs):
175
  model.train()
176
  total_loss = 0
177
+ for step, (x, y) in enumerate(dataloader):
178
+ x, y = x.to(device), y.to(device)
179
  optimizer.zero_grad()
180
+
181
+ with torch.cuda.amp.autocast(): # mixed precision
182
+ logits = model(x)
183
+ loss = loss_fn(logits.view(-1, vocab_size), y.view(-1))
184
+
185
+ scaler.scale(loss).backward()
186
+ scaler.unscale_(optimizer)
187
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
188
+ scaler.step(optimizer)
189
+ scaler.update()
190
+
191
  total_loss += loss.item()
192
  if step % 100 == 0:
193
  print(f"Epoch {epoch+1}, Step {step}, Loss: {loss.item():.4f}")
194
+
195
  scheduler.step()
196
  print(f"Epoch {epoch+1} ์™„๋ฃŒ, ํ‰๊ท  Loss: {total_loss/len(dataloader):.4f}")
197
 
198
  torch.save(model.state_dict(), "relm_model.pth")
199
+ print("โœ… ๋ชจ๋ธ ์ €์žฅ ์™„๋ฃŒ!")
200
 
201
  # ===============================
202
  # Top-p ์ƒ˜ํ”Œ๋ง ์ƒ์„ฑ