Trouter-Library committed on
Commit
4c5e9dd
·
verified ·
1 Parent(s): aaf8d08

Create train_embeddings.py

Browse files
Files changed (1) hide show
  1. train_embeddings.py +337 -0
train_embeddings.py ADDED
@@ -0,0 +1,337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Helion-V1-Embeddings Training Script
3
+ Train a lightweight embedding model for semantic similarity and retrieval
4
+ """
5
+
6
+ import json
7
+ import logging
8
+ from typing import List, Dict, Tuple
9
+ from pathlib import Path
10
+ from datetime import datetime
11
+
12
# Configure the root logger once at import time: INFO and above, each record
# rendered as "<timestamp> - <level> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-scoped logger shared by everything defined in this file.
logger = logging.getLogger(__name__)
17
+
18
+
19
class EmbeddingsTrainer:
    """Fine-tune a sentence-embedding model for Helion-V1-Embeddings.

    The trainer wraps a SentenceTransformer base model: it supplies a small
    built-in dataset of scored sentence pairs, fine-tunes the model with a
    cosine-similarity regression loss, reports a simple evaluation error and
    can write fallback configuration files into the output directory.
    """

    def __init__(
        self,
        base_model: str = "sentence-transformers/all-MiniLM-L6-v2",
        output_path: str = "./helion-embeddings-output"
    ):
        """
        Remember the base model and make sure the output directory exists.

        Args:
            base_model: Model name or path handed to SentenceTransformer
            output_path: Directory that receives the trained model and the
                configuration files (created, with parents, if missing)
        """
        self.base_model = base_model
        self.output_path = Path(output_path)
        self.output_path.mkdir(parents=True, exist_ok=True)

    def prepare_training_data(self) -> List[Dict]:
        """
        Return the built-in training data for embeddings.

        Returns:
            A list of dicts with the keys "sentence1", "sentence2" and
            "score", where score is the target similarity of the pair.
        """
        training_examples = [
            # High similarity pairs
            {
                "sentence1": "How do I reset my password?",
                "sentence2": "What's the password reset process?",
                "score": 0.95
            },
            {
                "sentence1": "Machine learning training methods",
                "sentence2": "How to train ML models",
                "score": 0.90
            },
            {
                "sentence1": "Python programming tutorial",
                "sentence2": "Learn Python coding",
                "score": 0.88
            },

            # Medium similarity pairs
            {
                "sentence1": "Install Python on Windows",
                "sentence2": "Python setup guide",
                "score": 0.70
            },
            {
                "sentence1": "Best restaurants in Paris",
                "sentence2": "Where to eat in France",
                "score": 0.65
            },

            # Low similarity pairs
            {
                "sentence1": "How to bake cookies",
                "sentence2": "Machine learning algorithms",
                "score": 0.10
            },
            {
                "sentence1": "Weather forecast tomorrow",
                "sentence2": "Stock market analysis",
                "score": 0.05
            }
        ]

        logger.info(f"Prepared {len(training_examples)} training examples")
        return training_examples

    def create_contrastive_pairs(self) -> List[Tuple[str, str]]:
        """
        Return the built-in pairs for contrastive learning.

        Returns:
            A list of (anchor, positive) sentence tuples.
        """
        pairs = [
            ("What is machine learning?", "Machine learning explained simply"),
            ("How to learn Python?", "Python learning resources"),
            ("Best coding practices", "Software development best practices"),
            ("Data science tutorial", "Learn data science basics"),
            ("Natural language processing", "NLP fundamentals guide"),
            ("Deep learning introduction", "Getting started with deep learning"),
            ("Web development guide", "How to build websites"),
            ("Database design principles", "SQL database design tutorial"),
            ("Cloud computing basics", "Introduction to cloud services"),
            ("API development guide", "How to create REST APIs"),
        ]

        logger.info(f"Created {len(pairs)} contrastive pairs")
        return pairs

    def train_model(
        self,
        train_examples: List[Dict] = None,
        epochs: int = 3,
        batch_size: int = 16,
        warmup_steps: int = 100
    ):
        """
        Train the embeddings model.

        Args:
            train_examples: Training data as dicts with "sentence1",
                "sentence2" and "score" keys (if None, uses the built-in
                data from prepare_training_data)
            epochs: Number of training epochs
            batch_size: Batch size for training
            warmup_steps: Warmup steps for learning rate; capped at the total
                number of optimisation steps of the run

        Returns:
            The trained model, or None when the training dependencies are
            missing or training fails (the failure is logged).
        """
        # Guard only the imports here so that an ImportError raised later,
        # while loading or training, is not misreported as "not installed".
        try:
            from sentence_transformers import (
                InputExample,
                SentenceTransformer,
                losses,
            )
            from torch.utils.data import DataLoader
        except ImportError:
            logger.error("sentence-transformers not installed. Install with: pip install sentence-transformers")
            return None

        try:
            logger.info("Loading base model...")
            model = SentenceTransformer(self.base_model)

            # Prepare data
            if train_examples is None:
                train_examples = self.prepare_training_data()

            # Convert to InputExample format. The score is coerced to float
            # because the loss regresses on it; an integer score such as 1
            # loaded from JSON would otherwise be passed on as an int.
            train_data = [
                InputExample(
                    texts=[example["sentence1"], example["sentence2"]],
                    label=float(example["score"])
                )
                for example in train_examples
            ]

            # Create DataLoader
            train_dataloader = DataLoader(
                train_data,
                shuffle=True,
                batch_size=batch_size
            )

            # Define loss function
            train_loss = losses.CosineSimilarityLoss(model)

            # A warm-up longer than the whole run would hold the learning
            # rate at a small fraction of its target for every step (the
            # built-in data yields only a handful of steps), so cap it.
            total_steps = len(train_dataloader) * epochs
            if warmup_steps > total_steps:
                logger.warning(
                    f"warmup_steps={warmup_steps} exceeds total training "
                    f"steps ({total_steps}); using {total_steps} instead"
                )
                warmup_steps = total_steps

            # Training
            logger.info("Starting training...")
            model.fit(
                train_objectives=[(train_dataloader, train_loss)],
                epochs=epochs,
                warmup_steps=warmup_steps,
                output_path=str(self.output_path),
                show_progress_bar=True,
                save_best_model=True
            )

            logger.info(f"✅ Training complete! Model saved to {self.output_path}")
        except Exception as e:
            # Top-level boundary of the training run: keep the "return None"
            # contract, but log the traceback so the failure can be diagnosed.
            logger.exception(f"Training failed: {e}")
            return None

        return model

    def evaluate_model(self, model, test_pairs: List[Tuple[str, str, float]] = None):
        """
        Evaluate the trained model.

        Args:
            model: Trained SentenceTransformer model
            test_pairs: List of (sentence1, sentence2, expected_similarity);
                if None, a small built-in list is used

        Returns:
            Mean absolute difference between the cosine similarity of each
            pair's embeddings and its expected similarity.

        Raises:
            ValueError: If test_pairs is empty.
        """
        if test_pairs is None:
            # Default test pairs
            test_pairs = [
                ("How to code?", "Coding tutorial", 0.85),
                ("Weather today", "Stock prices", 0.1),
                ("Machine learning", "AI and ML", 0.95),
            ]

        # An empty list would otherwise fail with a bare ZeroDivisionError
        # when the average is computed below.
        if not test_pairs:
            raise ValueError(
                "test_pairs must contain at least one "
                "(sentence1, sentence2, expected_similarity) tuple"
            )

        from sentence_transformers import util

        logger.info("Evaluating model...")

        total_error = 0
        for sent1, sent2, expected in test_pairs:
            emb1 = model.encode(sent1)
            emb2 = model.encode(sent2)
            similarity = float(util.cos_sim(emb1, emb2)[0][0])
            error = abs(similarity - expected)
            total_error += error

            logger.info(f"'{sent1}' <-> '{sent2}'")
            logger.info(f"  Expected: {expected:.2f}, Got: {similarity:.2f}, Error: {error:.2f}")

        avg_error = total_error / len(test_pairs)
        logger.info(f"Average error: {avg_error:.3f}")

        return avg_error

    def create_config_files(self, *, overwrite: bool = False):
        """
        Create necessary configuration files in the output directory.

        Writes "config_sentence_transformers.json" and "modules.json" with
        hard-coded default content. A file that already exists is left
        untouched unless overwrite is True: train_model passes this same
        directory to model.fit as output_path, so after training the
        directory is expected to hold the versions saved with the model, and
        replacing them with these defaults could describe a different module
        stack than the one that was trained.

        Args:
            overwrite: Replace existing files with the defaults
        """
        # Sentence transformers config
        config = {
            "__version__": {
                "sentence_transformers": "2.2.2",
                "transformers": "4.36.0",
                "pytorch": "2.0.0"
            },
            "prompts": {},
            "default_prompt_name": None,
            "similarity_fn_name": "cosine",
            "max_seq_length": 256,
            "do_lower_case": False
        }

        # Modules configuration
        modules = [
            {
                "idx": 0,
                "name": "0",
                "path": "",
                "type": "sentence_transformers.models.Transformer"
            },
            {
                "idx": 1,
                "name": "1",
                "path": "1_Pooling",
                "type": "sentence_transformers.models.Pooling"
            },
            {
                "idx": 2,
                "name": "2",
                "path": "2_Normalize",
                "type": "sentence_transformers.models.Normalize"
            }
        ]

        defaults = (
            ("config_sentence_transformers.json", config),
            ("modules.json", modules),
        )
        for filename, payload in defaults:
            target = self.output_path / filename
            if target.exists() and not overwrite:
                logger.info(f"Keeping existing {filename}")
                continue
            with open(target, 'w', encoding='utf-8') as f:
                json.dump(payload, f, indent=2)

        logger.info("✅ Configuration files created")
257
+
258
+
259
def main():
    """Command-line entry point: train, evaluate and report.

    Parses the command-line options, optionally loads training examples from
    a JSON file, trains the model, evaluates it on the default test pairs,
    writes the configuration files and prints usage instructions.

    Raises:
        SystemExit: With status 2 when the data file does not contain a JSON
            list, and with status 1 when training did not produce a model.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Train Helion-V1-Embeddings model"
    )
    parser.add_argument(
        "--base-model",
        default="sentence-transformers/all-MiniLM-L6-v2",
        help="Base model to fine-tune"
    )
    parser.add_argument(
        "--output",
        default="./helion-embeddings-output",
        help="Output directory"
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=3,
        help="Number of training epochs"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=16,
        help="Batch size"
    )
    parser.add_argument(
        "--warmup-steps",
        type=int,
        default=100,
        help="Warmup steps for learning rate"
    )
    parser.add_argument(
        "--data-file",
        type=str,
        help="Path to training data JSON file"
    )

    args = parser.parse_args()

    # Create trainer
    trainer = EmbeddingsTrainer(
        base_model=args.base_model,
        output_path=args.output
    )

    # Load custom data if provided
    train_examples = None
    if args.data_file:
        # JSON text is read as UTF-8 rather than the platform default.
        with open(args.data_file, 'r', encoding='utf-8') as f:
            train_examples = json.load(f)
        # The trainer iterates over example dicts; reject any other JSON
        # shape here with a usage error instead of failing mid-training.
        if not isinstance(train_examples, list):
            parser.error(f"{args.data_file} must contain a JSON list of examples")
        logger.info(f"Loaded {len(train_examples)} examples from {args.data_file}")

    # Train model
    model = trainer.train_model(
        train_examples=train_examples,
        epochs=args.epochs,
        batch_size=args.batch_size,
        warmup_steps=args.warmup_steps
    )

    # train_model signals failure with None (and has already logged the
    # reason). Compare by identity instead of truthiness, and make the
    # failure visible to the calling shell through the exit status.
    if model is None:
        raise SystemExit(1)

    # Evaluate
    trainer.evaluate_model(model)

    # Create config files
    trainer.create_config_files()

    print("\n" + "="*60)
    print("✅ Helion-V1-Embeddings Training Complete!")
    print("="*60)
    print(f"📁 Model saved to: {args.output}")
    print("\n💡 Test your model:")
    print("```python")
    print("from sentence_transformers import SentenceTransformer")
    print(f"model = SentenceTransformer('{args.output}')")
    print("embeddings = model.encode(['Hello world'])")
    print("```")
    print("="*60)


if __name__ == "__main__":
    main()