# File size: 4,583 Bytes | d65bc08
import json
import random
from sentence_transformers import SentenceTransformer, losses, InputExample, evaluation
from torch.utils.data import DataLoader
from rich.console import Console
from rich.table import Table
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TimeElapsedColumn
# Shared Rich console used for the status messages and summary table printed by main().
console = Console()
def load_triplets_from_jsonl(file_path):
    """Read triplet training examples from a JSON Lines file.

    Each non-blank line must be a JSON object with the keys ``anchor``,
    ``positive`` and ``negative``. Blank or whitespace-only lines (for
    example a trailing newline at the end of the file) are skipped instead
    of being handed to the JSON parser.

    Args:
        file_path: Path of the UTF-8 encoded ``.jsonl`` file.

    Returns:
        A list of ``InputExample`` objects whose ``texts`` are ordered
        ``[anchor, positive, negative]``, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks one of the three required keys.
    """
    examples = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            # json.loads() rejects empty input, so ignore separator/trailing blank lines.
            if not line.strip():
                continue
            data = json.loads(line)
            examples.append(
                InputExample(texts=[data['anchor'], data['positive'], data['negative']])
            )
    return examples
def split_data(examples, train_ratio=0.8, seed=42):
    """Shuffle *examples* reproducibly and split them into two lists.

    The shuffle runs on a copy with a private ``random.Random(seed)``
    generator, so the caller's sequence is left untouched and the
    module-level ``random`` state is not reseeded. The permutation is the
    same one ``random.seed(seed); random.shuffle(...)`` produces.

    Args:
        examples: Sequence of items to split.
        train_ratio: Fraction of items placed in the first (training) list;
            the count is ``int(len(examples) * train_ratio)``.
        seed: Seed for the shuffle.

    Returns:
        A ``(train, validation)`` tuple of new lists.
    """
    shuffled = list(examples)
    random.Random(seed).shuffle(shuffled)
    train_size = int(len(shuffled) * train_ratio)
    return shuffled[:train_size], shuffled[train_size:]
def main():
    """Fine-tune a SentenceTransformer on triplets and report progress with Rich.

    Loads triplets from ``triplets.jsonl``, splits them 80/20 into training
    and validation sets, trains ``all-MiniLM-L6-v2`` with ``TripletLoss``
    for one epoch while scoring the validation triplets every
    ``evaluation_steps`` training steps, passes ``fine_tuned_sbert_triplet``
    to ``model.fit`` as the output path, and finally prints a summary table.
    """
    model_name = 'all-MiniLM-L6-v2'
    model = SentenceTransformer(model_name)

    # Load all triplets from a single JSONL file
    all_examples = load_triplets_from_jsonl('triplets.jsonl')

    # Split into train and validation
    train_examples, val_examples = split_data(all_examples, train_ratio=0.8)

    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)
    train_loss = losses.TripletLoss(model=model)

    # The evaluator takes the validation triplets as three parallel lists.
    anchors = [ex.texts[0] for ex in val_examples]
    positives = [ex.texts[1] for ex in val_examples]
    negatives = [ex.texts[2] for ex in val_examples]

    num_epochs = 1
    # Single source of truth for the evaluation interval: used both by
    # model.fit() and by the progress bar advance in RichLossEvaluator.
    evaluation_steps = 10
    total_steps = num_epochs * len(train_dataloader)
    warmup_steps = int(total_steps * 0.1)
    output_path = 'fine_tuned_sbert_triplet'

    console.print(f"[bold]Starting training with {len(train_examples)} triplets, validating on {len(val_examples)} triplets.[/bold]\n")

    progress = Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        TimeElapsedColumn(),
        transient=True,
    )

    class RichLossEvaluator(evaluation.TripletEvaluator):
        """TripletEvaluator that logs each validation score and advances the Rich progress bar."""

        def __init__(self, *args, progress_task, **kwargs):
            super().__init__(*args, **kwargs)
            self.progress_task = progress_task

        def __call__(self, model, output_path=None, epoch=-1, steps=-1, **kwargs):
            score_dict = super().__call__(model, output_path=output_path, epoch=epoch, steps=steps, **kwargs)

            # The parent may return either a dict of metrics or a bare score.
            if isinstance(score_dict, dict):
                if "accuracy" in score_dict:
                    val_score = score_dict["accuracy"]
                else:
                    val_score = next(iter(score_dict.values()))
            else:
                val_score = score_dict

            if steps is not None and steps >= 0:
                # A step-based evaluation: one evaluation interval has elapsed.
                progress.update(self.progress_task, advance=evaluation_steps)
                label = f"Step {steps}"
            else:
                # steps keeps its default of -1 when the call is not tied to a
                # training step (presumably the end-of-epoch evaluation made by
                # fit() -- verify against the installed sentence-transformers
                # version). No training steps elapsed, so leave the bar alone.
                label = f"Epoch {epoch} end"
            progress.console.log(f"{label}: Validation score: {val_score:.4f}")
            return score_dict

    with progress:
        task = progress.add_task("[green]Training...[/green]", total=total_steps)
        rich_evaluator = RichLossEvaluator(
            anchors=anchors,
            positives=positives,
            negatives=negatives,
            name='val-triplet-eval',
            show_progress_bar=False,
            progress_task=task
        )
        model.fit(
            train_objectives=[(train_dataloader, train_loss)],
            epochs=num_epochs,
            warmup_steps=warmup_steps,
            evaluator=rich_evaluator,
            evaluation_steps=evaluation_steps,
            output_path=output_path,
            show_progress_bar=False
        )
        # The bar only moves in evaluation_steps increments; snap it to 100% at the end.
        progress.update(task, completed=total_steps)

    table = Table(title="Training Summary")
    table.add_column("Metric", style="cyan", no_wrap=True)
    table.add_column("Value", style="magenta")
    table.add_row("Model Name", model_name)
    table.add_row("Total Triplets", str(len(all_examples)))
    table.add_row("Training Triplets", str(len(train_examples)))
    table.add_row("Validation Triplets", str(len(val_examples)))
    table.add_row("Epochs", str(num_epochs))
    table.add_row("Output Path", output_path)
    console.print(table)
    console.print("[bold green]Model fine-tuned and saved successfully![/bold green]")
# Run the training pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
|