File size: 835 Bytes
784bcca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from sentence_transformers import SentenceTransformer, InputExample, losses, evaluation
from torch.utils.data import DataLoader
import json
import numpy as np

# Fine-tune a sentence-embedding model so that ad listings which share a
# category end up close together in embedding space.
#
# Input:  data/listings.json -- presumably a JSON array of objects, each with
#         at least a 'text' field and a 'category_id' field (TODO confirm
#         schema against the data producer).
# Output: the fine-tuned model, saved to models/ad_categorizer.

# 1. Load data (explicit UTF-8: listing text may contain non-ASCII characters,
#    and the default encoding otherwise depends on the platform locale).
with open('data/listings.json', encoding='utf-8') as f:
    train_data = json.load(f)

if not train_data:
    # An empty dataset would give an empty DataLoader; fail loudly instead.
    raise ValueError("data/listings.json contains no training listings")

# 2. Prepare examples.
# Batch triplet losses take ONE text per example plus an integer class label,
# so every distinct category_id (string or int) is mapped to a contiguous
# integer id, assigned in first-seen order.
label_ids = {}
train_examples = []
for item in train_data:
    label = label_ids.setdefault(item['category_id'], len(label_ids))
    train_examples.append(InputExample(texts=[item['text']], label=label))

# 3. Initialize model
model = SentenceTransformer('all-MiniLM-L6-v2')

# 4. Train with a batch triplet loss.
# ContrastiveLoss expects *pairs* of texts with a 0/1 similarity label, but
# these examples are single texts labelled by category. That is the input
# format BatchAllTripletLoss is built for: it forms anchor/positive/negative
# triplets inside each batch from the class labels.
# NOTE: a batch only contributes triplets if it contains >= 2 listings of some
# category plus at least one other category; with many rare categories,
# consider sentence_transformers.datasets.SentenceLabelDataset for batching.
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)
loss = losses.BatchAllTripletLoss(model=model)

model.fit(
    train_objectives=[(train_dataloader, loss)],
    epochs=3,
    warmup_steps=100,
)

# 5. Save model
model.save('models/ad_categorizer')
print("Training complete! Model saved.")