Commit
·
0c1d5b6
1
Parent(s):
93f233e
Update the multi-labelling prompt
Browse files
src/synthetic_dataset_generator/apps/textcat.py
CHANGED
|
@@ -105,7 +105,7 @@ def generate_dataset(
|
|
| 105 |
is_sample=is_sample,
|
| 106 |
)
|
| 107 |
labeller_generator = get_labeller_generator(
|
| 108 |
-
system_prompt=f"{system_prompt} {', '.join(labels)}",
|
| 109 |
labels=labels,
|
| 110 |
multi_label=multi_label,
|
| 111 |
)
|
|
@@ -134,7 +134,6 @@ def generate_dataset(
|
|
| 134 |
else:
|
| 135 |
k = 1
|
| 136 |
|
| 137 |
-
print(k)
|
| 138 |
sampled_labels = random.sample(labels, min(k, len(labels)))
|
| 139 |
random.shuffle(sampled_labels)
|
| 140 |
inputs.append(
|
|
|
|
| 105 |
is_sample=is_sample,
|
| 106 |
)
|
| 107 |
labeller_generator = get_labeller_generator(
|
| 108 |
+
system_prompt=f"{system_prompt}. Potential labels: {', '.join(labels)}",
|
| 109 |
labels=labels,
|
| 110 |
multi_label=multi_label,
|
| 111 |
)
|
|
|
|
| 134 |
else:
|
| 135 |
k = 1
|
| 136 |
|
|
|
|
| 137 |
sampled_labels = random.sample(labels, min(k, len(labels)))
|
| 138 |
random.shuffle(sampled_labels)
|
| 139 |
inputs.append(
|