File size: 598 Bytes
f7a8d72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import random
import warnings

from src.dataset import load_data


def get_few_shot_data(k=8, *, seed=None, data=None):
    """Build a class-balanced few-shot subset of the dataset.

    Splits the examples by label (label ``1`` vs. every other label),
    shuffles each group, takes up to ``k`` examples from each, then shuffles
    the combined selection so the two classes are interleaved.

    Args:
        k: Maximum number of examples to draw per class. If a class has
            fewer than ``k`` examples, all of them are used and a warning is
            issued, so the result may contain fewer than ``2 * k`` items.
        seed: Optional seed for a private ``random.Random`` instance, making
            the selection reproducible without touching the global RNG.
            When ``None`` (the default), the module-level ``random``
            functions are used, exactly as in the original behaviour.
        data: Optional ``(texts, labels)`` pair to sample from. When
            ``None`` (the default), the pair returned by ``load_data()`` is
            used.

    Returns:
        A ``(texts, labels)`` tuple of two lists in the same shuffled order.

    Raises:
        ValueError: If ``k`` is negative.
    """
    if k < 0:
        # A negative slice bound would silently mean "all but the last |k|".
        raise ValueError(f"k must be non-negative, got {k}")

    # The `random` module exposes the same `shuffle` API as `random.Random`,
    # so seed=None preserves the original global-RNG behaviour exactly.
    rng = random if seed is None else random.Random(seed)
    texts, labels = load_data() if data is None else data

    hate = []
    non_hate = []
    for text, label in zip(texts, labels):
        # Only label 1 goes into `hate`; every other label is grouped as non-hate.
        if label == 1:
            hate.append((text, label))
        else:
            non_hate.append((text, label))

    for group_name, group in (("hate", hate), ("non-hate", non_hate)):
        if len(group) < k:
            warnings.warn(
                f"Only {len(group)} {group_name} examples available "
                f"(requested k={k}); few-shot set will be smaller than 2*k.",
                stacklevel=2,
            )

    # Shuffle each class independently so slicing picks a random subset.
    rng.shuffle(hate)
    rng.shuffle(non_hate)

    # Pick up to k samples from each class.
    few_shot = hate[:k] + non_hate[:k]

    # Interleave the classes so they are not grouped together.
    rng.shuffle(few_shot)

    texts = [text for text, _ in few_shot]
    labels = [label for _, label in few_shot]

    print(f"Few-shot dataset size: {len(texts)}")

    return texts, labels