File size: 2,212 Bytes
4f2b2f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from torch.utils.data import Dataset
import numpy as np
import torch


def generate_bracket(length: int, seq: str = "") -> str:
    """Return a random bracket string that grows ``seq`` by ``length`` characters.

    The string is built by repeatedly either wrapping the current group in a
    matching ``"("`` / ``")"`` pair or closing the group and starting a new one
    to its right, so the result is balanced whenever ``seq`` is balanced.

    Args:
        length: Number of characters still to add. Must be non-negative and
            even, because brackets are always added in pairs.
        seq: Already generated prefix/group to extend (empty by default).

    Returns:
        A string of ``len(seq) + length`` characters.

    Raises:
        ValueError: If ``length`` is negative or odd (such a value could never
            reach zero in steps of two).
    """
    import random

    if length < 0 or length % 2:
        raise ValueError(
            f"length must be a non-negative even number, got {length!r}"
        )

    finished = []  # groups that are complete and sit to the left of `current`
    current = seq
    while length > 0:
        # Draw first, even when the group is empty, so the sequence of random
        # calls matches the recursive formulation this loop replaces.
        p = random.randint(0, 1)
        if p == 0 or current == "":
            # Wrap the current group in a matching pair of brackets.
            current = "(" + current + ")"
            length -= 2
        else:
            # Close the current group and start a fresh one after it.
            finished.append(current)
            current = ""
    finished.append(current)
    return "".join(finished)


class BracketDataset(Dataset):
    """In-memory dataset of randomly generated bracket sequences.

    Every item is a 1-D ``torch.long`` tensor in which ``"("`` is encoded as 1,
    any other character as 2, and the tail is filled with 3 until the tensor
    has ``max_length`` entries. ``parse_tensor`` additionally decodes token 0
    as ``"m"`` (presumably a mask token produced elsewhere -- confirm against
    the callers).
    """

    # Token id -> character used by ``parse_tensor``; token 3 decodes to nothing.
    _TOKEN_TO_CHAR = {0: "m", 1: "(", 2: ")", 3: ""}

    def __init__(self, n, length_probs, max_length=64):
        """Generate ``n`` sequences and print the resulting length distribution.

        Args:
            n: Number of sequences to generate.
            length_probs: Mapping from sequence length to the probability of
                drawing that length; passed to ``np.random.choice`` as ``p``.
            max_length: Number of entries each tensor is padded to. Sequences
                longer than this are stored unpadded and untruncated.
        """
        super().__init__()
        lengths = list(length_probs.keys())
        probs = [length_probs[k] for k in lengths]
        self.data = []

        # Track actual length distribution
        length_counts = {length: 0 for length in lengths}

        for _ in range(n):
            L = int(np.random.choice(lengths, p=probs))
            seq = generate_bracket(L)
            mapped = [1 if c == "(" else 2 for c in seq]
            # A negative multiplier yields an empty list, so over-long
            # sequences pass through without padding.
            mapped += [3] * (max_length - len(mapped))
            self.data.append(torch.tensor(mapped, dtype=torch.long))

            # Count actual sequence length
            actual_length = len(seq)
            length_counts[actual_length] = length_counts.get(actual_length, 0) + 1

        # Print length distribution
        print("Length distribution in dataset:")
        for length, count in sorted(length_counts.items()):
            # Guard the share so an empty dataset (n == 0) does not divide by zero.
            share = count / n if n else 0.0
            print(f"  Length {length}: {count} sequences ({share:.2%})")

    @staticmethod
    def parse_tensor(tensor: torch.Tensor):
        """Decode token ids back into bracket text.

        Args:
            tensor: A 1-D tensor of token ids, or a 2-D tensor with one
                sequence per row.

        Returns:
            A string for 1-D input, or a list of strings (one per row) for
            2-D input.

        Raises:
            ValueError: If ``tensor`` is not 1-D or 2-D.
            KeyError: If a token id is not one of 0, 1, 2, 3.
        """
        if tensor.dim() == 1:
            mapping = BracketDataset._TOKEN_TO_CHAR
            return "".join(mapping[int(token)] for token in tensor.tolist())
        elif tensor.dim() == 2:
            return [BracketDataset.parse_tensor(row) for row in tensor]
        elif tensor.dim() == 0:
            raise ValueError("input must have at least 1 dimension")
        else:
            raise ValueError("input cannot have dimension more than 2")

    def __len__(self):
        """Return the number of stored sequences."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the token tensor stored at position ``idx``."""
        return self.data[idx]


if __name__ == "__main__":
    # Demo run: build a 1000-sample dataset with a fixed mix of sequence
    # lengths, then dump every stored tensor to stdout.
    demo_size = 1000
    demo_length_probs = {4: 0.1, 8: 0.2, 32: 0.3, 64: 0.4}
    ds = BracketDataset(demo_size, demo_length_probs)

    position = 0
    while position < len(ds):
        print(ds[position])
        position += 1