Update README.md
Browse files
README.md
CHANGED
|
@@ -20,6 +20,52 @@ base_model:
|
|
| 20 |
|
| 21 |
## Train
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
## Evaluation
|
| 25 |
|
|
|
|
| 20 |
|
| 21 |
## Train
|
| 22 |
|
| 23 |
+
- Hardware: Google Colab, A100 40GB
|
| 24 |
+
- Data: jaeyong2/Ko-emb-PreView
|
| 25 |
+
|
| 26 |
+
```python
|
| 27 |
+
# Fine-tune an embedding model with a triplet objective: the passage
# ("context") is the anchor, its real title ("Title") is the positive and a
# fabricated title ("Fake Title") is the negative.
# NOTE(review): imports are not shown in this snippet; TripletLoss and
# batch_to_device are assumed to be imported/defined elsewhere -- confirm.
model_name = "Alibaba-NLP/gte-multilingual-base"
dataset = datasets.load_dataset("jaeyong2/Ko-emb-PreView")
train_dataloader = DataLoader(dataset['train'], batch_size=8, shuffle=True)

tokenizer = AutoTokenizer.from_pretrained(model_name)
# NOTE(review): the upstream model card loads this checkpoint with
# trust_remote_code=True -- confirm whether it is required here.
model = AutoModel.from_pretrained(model_name).to(torch.bfloat16)
triplet_loss = TripletLoss(margin=1.0)

optimizer = AdamW(model.parameters(), lr=5e-5)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch in range(3):  # iterate over epochs
    model.train()
    for batch in tqdm(train_dataloader):
        optimizer.zero_grad()
        loss = None
        # Samples are encoded one at a time (not as a padded batch), so each
        # forward pass sees exactly one anchor/positive/negative triplet.
        triplets = zip(batch["context"], batch["Title"], batch["Fake Title"])
        for context, title, fake_title in triplets:
            anchor_encodings = tokenizer([context], truncation=True, padding="max_length", max_length=4096, return_tensors="pt")
            positive_encodings = tokenizer([title], truncation=True, padding="max_length", max_length=256, return_tensors="pt")
            negative_encodings = tokenizer([fake_title], truncation=True, padding="max_length", max_length=256, return_tensors="pt")

            anchor_encodings = batch_to_device(anchor_encodings, device)
            positive_encodings = batch_to_device(positive_encodings, device)
            negative_encodings = batch_to_device(negative_encodings, device)

            # Model outputs (build the embedding vectors)
            anchor_output = model(**anchor_encodings)[0][:, 0, :]  # vector of the [CLS] token
            positive_output = model(**positive_encodings)[0][:, 0, :]
            negative_output = model(**negative_encodings)[0][:, 0, :]

            # Triplet loss for this sample, accumulated over the batch
            sample_loss = triplet_loss(anchor_output, positive_output, negative_output)
            loss = sample_loss if loss is None else loss + sample_loss

        # Average over the batch so the step size does not depend on the
        # (possibly smaller) size of the final batch.
        loss = loss / len(batch["context"])
        loss.backward()
        optimizer.step()
|
| 68 |
+
```
|
| 69 |
|
| 70 |
## Evaluation
|
| 71 |
|