File size: 1,874 Bytes
ffcfc75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from datasets import Dataset as HFDataset
from torch.utils.data import Dataset as TorchDataset
from typing import Dict, Any, Union, List


def truncate_dna(
    example: Dict[str, Any], truncate_dna_per_side: int = 1024
) -> Dict[str, Any]:
    """
    Trim a fixed number of characters from both ends of the DNA sequences.

    The values stored under ``"reference_sequence"`` and ``"variant_sequence"``
    are each shortened by ``truncate_dna_per_side`` characters at the start and
    at the end. A sequence is only trimmed when more than 8 characters would
    remain afterwards; sequences that are too short are left unchanged.

    Args:
        example: Record holding both sequence keys. It is modified in place.
        truncate_dna_per_side: Number of characters to drop from each end.
            ``0`` disables truncation.

    Returns:
        The same ``example`` object, with the sequences possibly shortened.

    Raises:
        ValueError: If ``truncate_dna_per_side`` is negative.
        KeyError: If either sequence key is missing from ``example``.
    """
    if truncate_dna_per_side < 0:
        raise ValueError(
            "truncate_dna_per_side must be non-negative, "
            f"got {truncate_dna_per_side}"
        )

    # A sequence is only trimmed when strictly more than this many characters
    # would be left over.
    min_remaining = 8

    for key in ("reference_sequence", "variant_sequence"):
        sequence = example[key]
        seq_len = len(sequence)

        if seq_len > 2 * truncate_dna_per_side + min_remaining:
            # Use an explicit end index: `sequence[n:-n]` collapses to an
            # empty slice when n == 0 because `-0` is just `0`.
            end = seq_len - truncate_dna_per_side
            example[key] = sequence[truncate_dna_per_side:end]

    return example


def torch_to_hf_dataset(torch_dataset: TorchDataset) -> HFDataset:
    """
    Convert a PyTorch Dataset to a Hugging Face Dataset.

    Every item of ``torch_dataset`` is read by integer index and collected into
    a column-oriented dictionary, which is then passed to
    ``HFDataset.from_dict``.

    * If the first item is a ``dict``, its keys become the columns. Later dict
      items must contain all of those keys; keys that only appear in later
      items are ignored.
    * Otherwise every item is stored as-is in a single column named ``"data"``.

    All items are assumed to share the structure of the first one.

    Args:
        torch_dataset: A PyTorch Dataset object to be converted. It must
            support ``len()`` and integer indexing.

    Returns:
        A Hugging Face Dataset containing the same data as the input PyTorch
        Dataset. An empty input yields a dataset built from an empty dict.

    Raises:
        KeyError: If a dict item lacks one of the first item's keys, or if
            dict and non-dict items are mixed.
    """
    num_items = len(torch_dataset)
    if num_items == 0:
        return HFDataset.from_dict({})

    # Fetch item 0 exactly once: it both defines the columns and supplies the
    # first row. Re-reading it could be expensive for lazily loaded datasets
    # and could disagree with itself for datasets with random transforms.
    first_item = torch_dataset[0]

    if isinstance(first_item, dict):
        columns = {key: [] for key in first_item}
    else:
        columns = {"data": []}

    def _append(item: Any) -> None:
        """Add one dataset item as a new row of ``columns``."""
        if isinstance(item, dict):
            for key in columns:
                columns[key].append(item[key])
        else:
            columns["data"].append(item)

    _append(first_item)
    # Index explicitly rather than iterating the dataset: map-style datasets
    # are only guaranteed to provide __len__ and __getitem__.
    for index in range(1, num_items):
        _append(torch_dataset[index])

    return HFDataset.from_dict(columns)