GoodWordsOnly / src /fewshot_data.py
Jathin Chetty
Vercel-ready API version without model artifacts
f7a8d72
raw
history blame contribute delete
598 Bytes
import random
from src.dataset import load_data
def get_few_shot_data(k=8, *, seed=None, data=None):
    """Build a small, class-balanced few-shot subset of the dataset.

    Examples with label ``1`` are grouped as "hate"; every other label is
    grouped as "non-hate". Each group is shuffled, up to ``k`` examples are
    taken from each, and the combined subset is shuffled so the classes are
    interleaved.

    Args:
        k: Maximum number of examples to draw from each group. A group with
            fewer than ``k`` examples contributes all of them, so the result
            may contain fewer than ``2 * k`` items.
        seed: Optional seed for a private ``random.Random`` so the subset is
            reproducible without touching global random state. When ``None``
            (the default) the module-level ``random`` functions are used,
            which honours any ``random.seed()`` call made by the caller.
        data: Optional ``(texts, labels)`` pair to sample from instead of
            calling ``load_data()``.

    Returns:
        A ``(texts, labels)`` tuple of equal-length lists forming the subset.

    Raises:
        ValueError: If ``k`` is negative. A negative slice bound would drop
            items from the end of each group instead of selecting ``k``.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")

    texts, labels = load_data() if data is None else data
    # Default to the module-level generator so existing behaviour (and any
    # caller-side random.seed()) is preserved; a seed isolates this call.
    rng = random if seed is None else random.Random(seed)

    hate = []
    non_hate = []
    for text, label in zip(texts, labels):
        # Only label 1 counts as hate; all other labels are non-hate.
        if label == 1:
            hate.append((text, label))
        else:
            non_hate.append((text, label))

    rng.shuffle(hate)
    rng.shuffle(non_hate)

    # Slicing caps each group at k without failing on smaller groups.
    few_shot = hate[:k] + non_hate[:k]
    # Shuffle again so the two classes are interleaved rather than blocked.
    rng.shuffle(few_shot)

    few_texts = [text for text, _ in few_shot]
    few_labels = [label for _, label in few_shot]
    print(f"Few-shot dataset size: {len(few_texts)}")
    return few_texts, few_labels