Yuchan
committed on
Update Model_torch.py
Browse files- Model_torch.py +20 -11
Model_torch.py
CHANGED
|
@@ -160,34 +160,43 @@ class ReLM(nn.Module):
|
|
| 160 |
logits = x @ self.token_embedding.weight.T
|
| 161 |
return logits
|
| 162 |
|
| 163 |
-
#
|
| 164 |
-
# ํ์ต
|
| 165 |
-
# ===============================
|
| 166 |
model = ReLM(vocab_size, max_len, 128, 2).to(device)
|
| 167 |
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
|
| 168 |
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
|
| 169 |
loss_fn = nn.CrossEntropyLoss(ignore_index=pad_id)
|
| 170 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
epochs = 1
|
| 172 |
for epoch in range(epochs):
|
| 173 |
model.train()
|
| 174 |
total_loss = 0
|
| 175 |
-
for step,(x,y) in enumerate(dataloader):
|
| 176 |
-
x,y = x.to(device), y.to(device)
|
| 177 |
optimizer.zero_grad()
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
total_loss += loss.item()
|
| 184 |
if step % 100 == 0:
|
| 185 |
print(f"Epoch {epoch+1}, Step {step}, Loss: {loss.item():.4f}")
|
|
|
|
| 186 |
scheduler.step()
|
| 187 |
print(f"Epoch {epoch+1} ์๋ฃ, ํ๊ท Loss: {total_loss/len(dataloader):.4f}")
|
| 188 |
|
| 189 |
torch.save(model.state_dict(), "relm_model.pth")
|
| 190 |
-
print("๋ชจ๋ธ ์ ์ฅ ์๋ฃ!")
|
| 191 |
|
| 192 |
# ===============================
|
| 193 |
# Top-p ์ํ๋ง ์์ฑ
|
|
|
|
| 160 |
logits = x @ self.token_embedding.weight.T
|
| 161 |
return logits
|
| 162 |
|
| 163 |
+
# ๋ชจ๋ธ, ์ตํฐ๋ง์ด์ , ์ค์ผ์ค๋ฌ, ์์ค ํจ์
|
|
|
|
|
|
|
| 164 |
model = ReLM(vocab_size, max_len, 128, 2).to(device)
|
| 165 |
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
|
| 166 |
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
|
| 167 |
loss_fn = nn.CrossEntropyLoss(ignore_index=pad_id)
|
| 168 |
|
| 169 |
+
# ์ ์ ๊ทธ๋ํ ์ปดํ์ผ
|
| 170 |
+
model = torch.compile(model, mode="default")
|
| 171 |
+
|
| 172 |
+
scaler = torch.cuda.amp.GradScaler()
|
| 173 |
epochs = 1
|
| 174 |
for epoch in range(epochs):
|
| 175 |
model.train()
|
| 176 |
total_loss = 0
|
| 177 |
+
for step, (x, y) in enumerate(dataloader):
|
| 178 |
+
x, y = x.to(device), y.to(device)
|
| 179 |
optimizer.zero_grad()
|
| 180 |
+
|
| 181 |
+
with torch.cuda.amp.autocast(): # mixed precision
|
| 182 |
+
logits = model(x)
|
| 183 |
+
loss = loss_fn(logits.view(-1, vocab_size), y.view(-1))
|
| 184 |
+
|
| 185 |
+
scaler.scale(loss).backward()
|
| 186 |
+
scaler.unscale_(optimizer)
|
| 187 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
| 188 |
+
scaler.step(optimizer)
|
| 189 |
+
scaler.update()
|
| 190 |
+
|
| 191 |
total_loss += loss.item()
|
| 192 |
if step % 100 == 0:
|
| 193 |
print(f"Epoch {epoch+1}, Step {step}, Loss: {loss.item():.4f}")
|
| 194 |
+
|
| 195 |
scheduler.step()
|
| 196 |
print(f"Epoch {epoch+1} ์๋ฃ, ํ๊ท Loss: {total_loss/len(dataloader):.4f}")
|
| 197 |
|
| 198 |
torch.save(model.state_dict(), "relm_model.pth")
|
| 199 |
+
print("โ
๋ชจ๋ธ ์ ์ฅ ์๋ฃ!")
|
| 200 |
|
| 201 |
# ===============================
|
| 202 |
# Top-p ์ํ๋ง ์์ฑ
|