# Model Card for Japanese-gte-multilingual-base-ONNX
ONNX version of jaeyong2/gte-multilingual-base-Ja-embedding, which fine-tunes the Alibaba-NLP/gte-multilingual-base model on the jaeyong2/Ja-emb-PreView dataset for better adaptation to Japanese.
## Model Details

- Base model: Alibaba-NLP/gte-multilingual-base
## Train

- Data: jaeyong2/Ja-emb-PreView
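The training script below indexes each batch by the columns `context`, `Title`, and `Fake Title`, so the dataset appears to provide (anchor, positive, negative) triplets. A quick way to inspect one record (a sketch; the column names are taken from the training code, not verified against the dataset card):

```python
import datasets

# Peek at one training example to see the triplet fields
ds = datasets.load_dataset("jaeyong2/Ja-emb-PreView", split="train")
example = ds[0]
print(example["context"][:200])  # anchor passage
print(example["Title"])          # positive title
print(example["Fake Title"])     # negative title
```

The full training script: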
```python
import torch
import datasets
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel
from torch.optim import AdamW
from tqdm import tqdm
from torch import nn


class TripletLoss(nn.Module):
    """Triplet margin loss over squared Euclidean distances."""

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        distance_positive = (anchor - positive).pow(2).sum(1)
        distance_negative = (anchor - negative).pow(2).sum(1)
        losses = torch.relu(distance_positive - distance_negative + self.margin)
        return losses.mean()


def batch_to_device(batch, device):
    return {key: value.to(device) for key, value in batch.items()}


model_name = "Alibaba-NLP/gte-multilingual-base"
dataset = datasets.load_dataset("jaeyong2/Ja-emb-PreView")
train_dataloader = DataLoader(dataset["train"], batch_size=8, shuffle=True)

tokenizer = AutoTokenizer.from_pretrained(model_name)
# gte-multilingual-base ships custom modeling code, so trust_remote_code is required
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)

triplet_loss = TripletLoss(margin=1.0)
optimizer = AdamW(model.parameters(), lr=5e-5)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    model = model.to(torch.bfloat16)

for epoch in range(3):
    model.train()
    total_loss = 0
    count = 0
    print(f"\nEpoch {epoch + 1}/3")
    for batch in tqdm(train_dataloader, desc=f"Training Epoch {epoch + 1}"):
        optimizer.zero_grad()
        loss = None
        # Encode each (context, Title, Fake Title) triplet one example at a time
        for index in range(len(batch["context"])):
            anchor_encodings = tokenizer(
                [batch["context"][index]],
                truncation=True,
                padding="max_length",
                max_length=1024,
                return_tensors="pt",
            )
            positive_encodings = tokenizer(
                [batch["Title"][index]],
                truncation=True,
                padding="max_length",
                max_length=256,
                return_tensors="pt",
            )
            negative_encodings = tokenizer(
                [batch["Fake Title"][index]],
                truncation=True,
                padding="max_length",
                max_length=256,
                return_tensors="pt",
            )

            anchor_encodings = batch_to_device(anchor_encodings, device)
            positive_encodings = batch_to_device(positive_encodings, device)
            negative_encodings = batch_to_device(negative_encodings, device)

            # Use the [CLS] token embedding from the last hidden state
            anchor_output = model(**anchor_encodings)[0][:, 0, :]
            positive_output = model(**positive_encodings)[0][:, 0, :]
            negative_output = model(**negative_encodings)[0][:, 0, :]

            if loss is None:
                loss = triplet_loss(anchor_output, positive_output, negative_output)
            else:
                loss += triplet_loss(anchor_output, positive_output, negative_output)

        # Average the accumulated triplet losses over the batch
        loss /= len(batch["context"])
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        count += 1

    avg_loss = total_loss / count
    print(f"Epoch {epoch + 1} - Average Loss: {avg_loss:.4f}")
```
## Inference

Code:

```python
model = ONNXEmbeddingModel(model_path)

multilingual_texts = [
    "Machine learning is fascinating",
    "機械学習は魅力的です",  # Japanese: Machine learning is fascinating
    "L'apprentissage automatique est fascinant",  # French
]

ml_embeddings = model.encode(multilingual_texts, normalize=True)
ml_similarities = model.similarity(ml_embeddings, ml_embeddings)

print("Cross-lingual similarities:")
for i, text1 in enumerate(multilingual_texts):
    for j, text2 in enumerate(multilingual_texts):
        if i < j:
            sim = ml_similarities[i, j]
            print(f" {sim:.4f}: '{text1[:30]}...' <-> '{text2[:30]}...'")
```
## License

- Alibaba-NLP/gte-multilingual-base: Apache 2.0 (https://choosealicense.com/licenses/apache-2.0/)