Pringled committed (verified)
Commit e806337 · 1 Parent(s): 524bccc

Replace inline training script with link to train.py

Files changed (1):
  1. README.md +4 -309
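For auditing this change, the commit hash above can be pinned when downloading the repo. A minimal sketch with `huggingface_hub`; the repo id `minishlab/potion-code-16M` is an assumption inferred from the model name in the script, not stated on this page:

```python
# Hypothetical usage: pin the commit reviewed here when fetching repo files.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(
    repo_id="minishlab/potion-code-16M",  # assumed repo id (not stated on this page)
    revision="e806337",                   # the commit hash from the header above
)
print(local_dir)
```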
README.md CHANGED
@@ -84,316 +84,11 @@ CoIR covers a broad range of code retrieval scenarios. For the use case of findi
 
 ## Reproducibility
 
-The following script reproduces this model end-to-end. It requires the tokenlearn training data from `minishlab/tokenlearn-cornstack-docs-coderankembed` and `minishlab/tokenlearn-cornstack-queries-coderankembed` (20k samples per language used).
-
-```python
-"""Reproduction script for potion-code-16M.
-
-Runs the full pipeline: distill → tokenlearn → contrastive fine-tuning.
-
-Requirements:
-    pip install model2vec tokenlearn sentence-transformers datasets skeletoken einops
-
-The three model checkpoints are saved to:
-    ./models/potion-code-16M-distilled
-    ./models/potion-code-16M-tokenlearn
-    ./models/potion-code-16M-contrastive  ← final model
-"""
-
-from __future__ import annotations
-
-import logging
-import random
-
-import numpy as np
-import torch
-from datasets import Dataset, concatenate_datasets, load_dataset
-from huggingface_hub import snapshot_download
-from model2vec import StaticModel
-from model2vec.distill import distill_from_model
-from model2vec.distill.inference import post_process_embeddings
-from pathlib import Path
-from sentence_transformers import (
-    SentenceTransformer,
-    SentenceTransformerTrainer,
-    SentenceTransformerTrainingArguments,
-)
-from sentence_transformers.losses import MultipleNegativesRankingLoss
-from sentence_transformers.models import StaticEmbedding
-from sentence_transformers.training_args import BatchSamplers
-from skeletoken import TokenizerModel
-from sklearn.decomposition import PCA
-from tokenlearn.losses import Loss
-from tokenlearn.model import StaticModelForFineTuning
-from tokenlearn.utils import create_vocab
-from transformers import AutoModel, AutoTokenizer
-
-logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
-logger = logging.getLogger(__name__)
-
-# ---------------------------------------------------------------------------
-# Hyperparameters
-# ---------------------------------------------------------------------------
-
-TEACHER_MODEL = "nomic-ai/CodeRankEmbed"
-OUTPUT_DIR = Path("models")
-
-# Distill
-VOCAB_SIZE = 42_000  # extra tokens mined from CornStack → ~62.5k total → ~16M params
-PCA_DIMS = 256
-SIF_COEFFICIENT = 1e-4
-
-# Tokenlearn
-TOKENLEARN_DOCS_DATASET = "minishlab/tokenlearn-cornstack-docs-coderankembed"
-TOKENLEARN_QUERIES_DATASET = "minishlab/tokenlearn-cornstack-queries-coderankembed"
-TOKENLEARN_LANGUAGES = ["go", "java", "javascript", "php", "python", "ruby"]
-TOKENLEARN_MAX_PER_LANGUAGE = 20_000  # 20k docs + 20k queries × 6 langs = 240k total
-TOKENLEARN_LR = 1e-3
-TOKENLEARN_MAX_EPOCHS = 20  # early stopping (patience=5) typically kicks in earlier
-TOKENLEARN_BATCH_SIZE = 128
-
-# Contrastive
-CORNSTACK_DATASETS = {
-    "python": "nomic-ai/cornstack-python-v1",
-    "java": "nomic-ai/cornstack-java-v1",
-    "php": "nomic-ai/cornstack-php-v1",
-    "go": "nomic-ai/cornstack-go-v1",
-    "javascript": "nomic-ai/cornstack-javascript-v1",
-    "ruby": "nomic-ai/cornstack-ruby-v1",
-}
-CONTRASTIVE_MAX_PER_LANGUAGE = 20_000  # 20k × 6 langs = 120k pairs total
-CONTRASTIVE_LR = 5e-3
-CONTRASTIVE_EPOCHS = 3
-CONTRASTIVE_BATCH_SIZE = 512
-CONTRASTIVE_SEED = 42
-
-
-# ---------------------------------------------------------------------------
-# Helpers
-# ---------------------------------------------------------------------------
-
-def apply_post_sif(model: StaticModel, pca_dims: int, sif_coefficient: float) -> StaticModel:
-    embeddings_np = model.embedding.astype(np.float32)
-    processed, weights = post_process_embeddings(
-        embeddings_np, pca_dims=pca_dims, sif_coefficient=sif_coefficient
-    )
-    logger.info("post_process_embeddings: %s → %s", embeddings_np.shape, processed.shape)
-    model.embedding = processed
-    model.weights = weights
-    return model
-
-
-# ---------------------------------------------------------------------------
-# Step 1: Distill
-# ---------------------------------------------------------------------------
-
-def run_distill(save_path: Path) -> None:
-    logger.info("Downloading %s ...", TEACHER_MODEL)
-    local_path = snapshot_download(TEACHER_MODEL)
-    model = AutoModel.from_pretrained(local_path, trust_remote_code=True)
-    tokenizer = AutoTokenizer.from_pretrained(local_path, trust_remote_code=True, use_fast=True)
-
-    # Load tokenlearn corpus texts for vocab mining (docs + queries, 20k/lang)
-    logger.info("Loading texts for vocabulary mining ...")
-    shards = []
-    for lang in TOKENLEARN_LANGUAGES:
-        docs = load_dataset(TOKENLEARN_DOCS_DATASET, name=lang, split=f"train[:{TOKENLEARN_MAX_PER_LANGUAGE}]")
-        queries = load_dataset(TOKENLEARN_QUERIES_DATASET, name=lang, split=f"train[:{TOKENLEARN_MAX_PER_LANGUAGE}]")
-        shards.extend([docs, queries])
-    corpus = concatenate_datasets(shards)
-    texts: list[str] = list(corpus["text"])
-    logger.info("Loaded %d texts for vocab mining.", len(texts))
-
-    logger.info("Mining vocabulary (target size=%d) ...", VOCAB_SIZE)
-    vocab = create_vocab(texts=texts, vocab_size=VOCAB_SIZE)
-    logger.info("Mined %d tokens.", len(vocab))
-
-    # Filter: keep only new single-token entries not already in CodeRankEmbed vocabulary.
-    tokenizer_model = TokenizerModel.from_transformers_tokenizer(tokenizer).prune_added_tokens()
-    preprocessor = tokenizer_model.preprocessor
-    seen = set(tokenizer_model.sorted_vocabulary)
-    filtered = []
-    for token in vocab:
-        preprocessed = preprocessor.preprocess(token)
-        if len(preprocessed) == 1 and preprocessed[0] not in seen:
-            seen.add(preprocessed[0])
-            filtered.append(preprocessed[0])
-    logger.info("Vocabulary after filtering: %d tokens added to CodeRankEmbed.", len(filtered))
-
-    # NomicBERT requires monkey-patched embedding accessors.
-    model.get_input_embeddings = lambda: model.embeddings.word_embeddings
-    model.set_input_embeddings = lambda v: setattr(model.embeddings, "word_embeddings", v)
-
-    logger.info("Distilling (pca_dims=%d, sif=%g) ...", PCA_DIMS, SIF_COEFFICIENT)
-    static_model = distill_from_model(
-        model=model,
-        tokenizer=tokenizer,
-        vocabulary=filtered,
-        pca_dims=PCA_DIMS,
-        sif_coefficient=SIF_COEFFICIENT,
-        pooling="mean",
-        quantize_to="float32",
-    )
-
-    save_path.mkdir(parents=True, exist_ok=True)
-    static_model.save_pretrained(str(save_path))
-    logger.info("Distilled model saved to %s (vocab=%d, dims=%d)",
-                save_path, static_model.embedding.shape[0], static_model.embedding.shape[1])
-
-
-# ---------------------------------------------------------------------------
-# Step 2: Tokenlearn
-# ---------------------------------------------------------------------------
-
-def run_tokenlearn(base_model_path: Path, save_path: Path) -> None:
-    # Load 20k docs + 20k queries per language → 240k total
-    logger.info("Loading tokenlearn data (docs + queries, %d/lang × %d langs) ...",
-                TOKENLEARN_MAX_PER_LANGUAGE, len(TOKENLEARN_LANGUAGES))
-    shards = []
-    for lang in TOKENLEARN_LANGUAGES:
-        docs = load_dataset(TOKENLEARN_DOCS_DATASET, name=lang, split=f"train[:{TOKENLEARN_MAX_PER_LANGUAGE}]")
-        queries = load_dataset(TOKENLEARN_QUERIES_DATASET, name=lang, split=f"train[:{TOKENLEARN_MAX_PER_LANGUAGE}]")
-        shards.extend([docs, queries])
-    dataset = concatenate_datasets(shards)
-    logger.info("Total samples: %d", len(dataset))
-
-    train_txt: list[str] = list(dataset["text"])
-    train_vec = np.array(dataset["embedding"], dtype=np.float32)
-    non_nan_mask = ~np.isnan(train_vec).any(axis=1)
-    train_txt = np.array(train_txt)[non_nan_mask].tolist()
-    train_vec = train_vec[non_nan_mask]
-    logger.info("Loaded %d samples, raw vector shape: %s", len(train_txt), train_vec.shape)
-
-    logger.info("Fitting PCA to %d dims ...", PCA_DIMS)
-    pca = PCA(n_components=PCA_DIMS)
-    train_vec = pca.fit_transform(train_vec)
-    logger.info("Explained variance: %.4f. Shape: %s",
-                pca.explained_variance_ratio_.cumsum()[-1], train_vec.shape)
-
-    logger.info("Loading base model from %s ...", base_model_path)
-    base_model = StaticModel.from_pretrained(str(base_model_path), force_download=False)
-    if base_model.embedding.dtype != np.float32:
-        base_model.embedding = base_model.embedding.astype(np.float32)
-
-    trainable = StaticModelForFineTuning.from_static_model(
-        model=base_model,
-        out_dim=PCA_DIMS,
-        loss=Loss("cosine"),
-    )
-    logger.info("Training tokenlearn (lr=%g, max_epochs=%d, batch=%d) ...",
-                TOKENLEARN_LR, TOKENLEARN_MAX_EPOCHS, TOKENLEARN_BATCH_SIZE)
-    trainable.fit(
-        X=train_txt,
-        y=torch.from_numpy(train_vec.astype(np.float32)),
-        batch_size=TOKENLEARN_BATCH_SIZE,
-        learning_rate=TOKENLEARN_LR,
-        max_epochs=TOKENLEARN_MAX_EPOCHS,
-        early_stopping_patience=5,
-        use_wandb=False,
-    )
-    logger.info("Tokenlearn training complete.")
-
-    trained_model = trainable.to_static_model()
-    trained_model = apply_post_sif(trained_model, pca_dims=PCA_DIMS, sif_coefficient=SIF_COEFFICIENT)
-
-    save_path.mkdir(parents=True, exist_ok=True)
-    trained_model.save_pretrained(str(save_path))
-    logger.info("Tokenlearn model saved to %s", save_path)
-
-
-# ---------------------------------------------------------------------------
-# Step 3: Contrastive fine-tuning (MNRL)
-# ---------------------------------------------------------------------------
-
-def run_contrastive(base_model_path: Path, save_path: Path) -> None:
-    random.seed(CONTRASTIVE_SEED)
-
-    logger.info("Streaming CornStack pairs (%d/lang × %d langs) ...",
-                CONTRASTIVE_MAX_PER_LANGUAGE, len(CORNSTACK_DATASETS))
-    all_queries: list[str] = []
-    all_docs: list[str] = []
-    for lang, hf_name in CORNSTACK_DATASETS.items():
-        hf_ds = load_dataset(hf_name, split="train", streaming=True)
-        hf_ds = hf_ds.shuffle(seed=CONTRASTIVE_SEED, buffer_size=10_000)
-        kept = 0
-        seen_q: set[str] = set()
-        seen_d: set[str] = set()
-        for row in hf_ds:
-            q, d = row.get("query"), row.get("document")
-            if not isinstance(q, str) or not isinstance(d, str):
-                continue
-            if len(q) < 32 or len(d) < 32:
-                continue
-            if q in seen_q or d in seen_d:
-                continue
-            seen_q.add(q)
-            seen_d.add(d)
-            all_queries.append(q)
-            all_docs.append(d)
-            kept += 1
-            if kept >= CONTRASTIVE_MAX_PER_LANGUAGE:
-                break
-        logger.info(" %s: %d pairs", lang, kept)
-
-    logger.info("Total pairs: %d", len(all_queries))
-    train_dataset = Dataset.from_dict({"anchor": all_queries, "positive": all_docs})
-
-    static_embedding = StaticEmbedding.from_model2vec(str(base_model_path))
-    model = SentenceTransformer(modules=[static_embedding])
-    loss = MultipleNegativesRankingLoss(model)
-
-    training_args = SentenceTransformerTrainingArguments(
-        output_dir=str(save_path) + "-checkpoints",
-        num_train_epochs=CONTRASTIVE_EPOCHS,
-        per_device_train_batch_size=CONTRASTIVE_BATCH_SIZE,
-        learning_rate=CONTRASTIVE_LR,
-        warmup_steps=0.1,
-        fp16=False,
-        bf16=False,
-        batch_sampler=BatchSamplers.NO_DUPLICATES,
-        save_strategy="no",
-        logging_steps=100,
-        logging_first_step=True,
-        report_to=[],
-    )
-    logger.info("Training contrastive (lr=%g, epochs=%d, batch=%d) ...",
-                CONTRASTIVE_LR, CONTRASTIVE_EPOCHS, CONTRASTIVE_BATCH_SIZE)
-
-    trainer = SentenceTransformerTrainer(
-        model=model, args=training_args, train_dataset=train_dataset, loss=loss,
-    )
-    trainer.train()
-    logger.info("Contrastive training complete.")
-
-    base_m2v = StaticModel.from_pretrained(str(base_model_path), force_download=False)
-    base_m2v.embedding = model[0].embedding.weight.detach().cpu().float().numpy()
-
-    final_model = apply_post_sif(base_m2v, pca_dims=PCA_DIMS, sif_coefficient=SIF_COEFFICIENT)
-
-    save_path.mkdir(parents=True, exist_ok=True)
-    final_model.save_pretrained(str(save_path))
-    logger.info("Final model saved to %s", save_path)
-
-
-# ---------------------------------------------------------------------------
-# Main
-# ---------------------------------------------------------------------------
-
-if __name__ == "__main__":
-    distilled_path = OUTPUT_DIR / "potion-code-16M-distilled"
-    tokenlearn_path = OUTPUT_DIR / "potion-code-16M-tokenlearn"
-    contrastive_path = OUTPUT_DIR / "potion-code-16M-contrastive"
-
-    logger.info("=== Step 1/3: Distill ===")
-    run_distill(save_path=distilled_path)
-
-    logger.info("=== Step 2/3: Tokenlearn ===")
-    run_tokenlearn(base_model_path=distilled_path, save_path=tokenlearn_path)
-
-    logger.info("=== Step 3/3: Contrastive ===")
-    run_contrastive(base_model_path=tokenlearn_path, save_path=contrastive_path)
-
-    logger.info("Done. Final model: %s", contrastive_path)
+The full training pipeline (distill → tokenlearn → contrastive) is in [`train.py`](./train.py). It requires `minishlab/tokenlearn-cornstack-docs-coderankembed` and `minishlab/tokenlearn-cornstack-queries-coderankembed` (20k samples per language used).
+
+```
+pip install model2vec tokenlearn sentence-transformers datasets skeletoken einops
+python train.py
 ```
 
 ## Citation