Model Card for bunbohue/Japanese-gte-multilingual-base-ONNX

ONNX version of jaeyong2/gte-multilingual-base-Ja-embedding, which fine-tunes the Alibaba-NLP/gte-multilingual-base model on the jaeyong2/Ja-emb-PreView dataset for better adaptation to Japanese.

Model Details

Base model: Alibaba-NLP/gte-multilingual-base

Train

  • Data : jaeyong2/Ja-emb-PreView
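
The fine-tuning script below draws triplets from three columns of this dataset: "context" (anchor), "Title" (positive), and "Fake Title" (negative). A quick way to inspect one example (a minimal sketch, assuming the default "train" split):

import datasets

# Load the triplet dataset and look at the fields the training loop uses
ds = datasets.load_dataset("jaeyong2/Ja-emb-PreView")
sample = ds["train"][0]
print(sample["context"][:100])  # anchor passage (truncated for display)
print(sample["Title"])          # positive title
print(sample["Fake Title"])     # negative title
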
import torch
import datasets
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel
from torch.optim import AdamW
from tqdm import tqdm
from torch import nn

# Triplet margin loss: pushes the anchor closer to the positive than to the
# negative by at least `margin`, measured in squared Euclidean distance
class TripletLoss(nn.Module):
    def __init__(self, margin=1.0):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        distance_positive = (anchor - positive).pow(2).sum(1)
        distance_negative = (anchor - negative).pow(2).sum(1)
        losses = torch.relu(distance_positive - distance_negative + self.margin)
        return losses.mean()

def batch_to_device(batch, device):
    return {key: value.to(device) for key, value in batch.items()}

model_name = "Alibaba-NLP/gte-multilingual-base"
dataset = datasets.load_dataset("jaeyong2/Ja-emb-PreView")
train_dataloader = DataLoader(dataset['train'], batch_size=8, shuffle=True)

tokenizer = AutoTokenizer.from_pretrained(model_name)
# gte-multilingual-base uses a custom model class, so loading it requires trust_remote_code
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
triplet_loss = TripletLoss(margin=1.0)

optimizer = AdamW(model.parameters(), lr=5e-5)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    model = model.to(torch.bfloat16)

for epoch in range(3):
    model.train()
    total_loss = 0
    count = 0

    print(f"\nEpoch {epoch + 1}/3")

    for batch in tqdm(train_dataloader, desc=f"Training Epoch {epoch + 1}"):
        optimizer.zero_grad()
        loss = None

        for index in range(len(batch["context"])):
            anchor_encodings = tokenizer(
                [batch["context"][index]],
                truncation=True,
                padding="max_length",
                max_length=1024,
                return_tensors="pt"
            )
            positive_encodings = tokenizer(
                [batch["Title"][index]],
                truncation=True,
                padding="max_length",
                max_length=256,
                return_tensors="pt"
            )
            negative_encodings = tokenizer(
                [batch["Fake Title"][index]],
                truncation=True,
                padding="max_length",
                max_length=256,
                return_tensors="pt"
            )

            anchor_encodings = batch_to_device(anchor_encodings, device)
            positive_encodings = batch_to_device(positive_encodings, device)
            negative_encodings = batch_to_device(negative_encodings, device)

            # Use the [CLS] token (first position of the last hidden state) as the sentence embedding
            anchor_output = model(**anchor_encodings)[0][:, 0, :]
            positive_output = model(**positive_encodings)[0][:, 0, :]
            negative_output = model(**negative_encodings)[0][:, 0, :]

            if loss is None:
                loss = triplet_loss(anchor_output, positive_output, negative_output)
            else:
                loss += triplet_loss(anchor_output, positive_output, negative_output)

        loss /= len(batch["context"])

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        count += 1

    avg_loss = total_loss / count
    print(f"Epoch {epoch + 1} - Average Loss: {avg_loss:.4f}")

Inference

Code:

# model_path points to the exported ONNX file; ONNXEmbeddingModel is a helper
# wrapper (see the sketch after this example)
model = ONNXEmbeddingModel(model_path)

multilingual_texts = [
    "Machine learning is fascinating",
    "機械学習は魅力的です",  # Japanese: Machine learning is fascinating
    "L'apprentissage automatique est fascinant",  # French
]

ml_embeddings = model.encode(multilingual_texts, normalize=True)
ml_similarities = model.similarity(ml_embeddings, ml_embeddings)

print("Cross-lingual similarities:")
for i, text1 in enumerate(multilingual_texts):
    for j, text2 in enumerate(multilingual_texts):
        if i < j:
            sim = ml_similarities[i, j]
            print(f"  {sim:.4f}: '{text1[:30]}...' <-> '{text2[:30]}...'")
