Commit: Update train_script.py

Changed file: train_script.py (+2 −1)
```diff
@@ -28,7 +28,7 @@ model = SentenceTransformer(
 # 3. Load a dataset to finetune on
 dataset = load_dataset("sentence-transformers/gooaq", split="train")
 dataset = dataset.add_column("id", range(len(dataset)))
-dataset_dict = dataset.train_test_split(test_size=10_000)
+dataset_dict = dataset.train_test_split(test_size=10_000, seed=12)
 train_dataset: Dataset = dataset_dict["train"]
 eval_dataset: Dataset = dataset_dict["test"]

@@ -62,6 +62,7 @@ args = SentenceTransformerTrainingArguments(
 # 6. (Optional) Create an evaluator & evaluate the base model
 # The full corpus, but only the evaluation queries
 # corpus = dict(zip(dataset["id"], dataset["answer"]))
+random.seed(12)
 queries = dict(zip(eval_dataset["id"], eval_dataset["question"]))
 corpus = (
     {qid: dataset[qid]["answer"] for qid in queries} |
```