win2win committed on
Commit
784bcca
·
verified ·
1 Parent(s): 8088f36

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +33 -0
train.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Fine-tune a SentenceTransformer on category-labelled listing texts.

Reads ``data/listings.json``, builds one training example per listing
(its ``text`` labelled with its ``category_id``), fine-tunes the
``all-MiniLM-L6-v2`` model and saves it to ``models/ad_categorizer``.
"""

import json

import numpy as np
from sentence_transformers import SentenceTransformer, InputExample, losses, evaluation
from torch.utils.data import DataLoader

# 1. Load data
# Explicit encoding so the result does not depend on the platform's default.
with open('data/listings.json', encoding='utf-8') as f:
    train_data = json.load(f)

# 2. Prepare examples
# One single-text example per listing, labelled with its category.
# NOTE(review): assumes every item has 'text' and an integer 'category_id'
# -- confirm against the data file.
train_examples = [
    InputExample(texts=[item['text']], label=item['category_id'])
    for item in train_data
]

# Fail early with a clear message instead of an opaque sampler error.
if not train_examples:
    raise ValueError("data/listings.json contains no training examples")

# 3. Initialize model
model = SentenceTransformer('all-MiniLM-L6-v2')

# 4. Train with a batch triplet loss
# ContrastiveLoss expects *pairs* of texts with a 0/1 similarity label, so it
# cannot consume the single-text, class-labelled examples built above.
# BatchAllTripletLoss is the loss documented for (sentence, class label)
# input: it forms the triplets inside each batch from the labels.
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)
loss = losses.BatchAllTripletLoss(model=model)

model.fit(
    train_objectives=[(train_dataloader, loss)],
    epochs=3,
    warmup_steps=100
)

# 5. Save model
model.save('models/ad_categorizer')
print("Training complete! Model saved.")