yuccaaa's picture
Add files using upload-large-folder tool
ffcfc75 verified
from datasets import Dataset as HFDataset
from torch.utils.data import Dataset as TorchDataset
from typing import Dict, Any, Union, List
def truncate_dna(
    example: Dict[str, Any], truncate_dna_per_side: int = 1024
) -> Dict[str, Any]:
    """
    Trim ``truncate_dna_per_side`` elements from both ends of each DNA sequence.

    Both ``example["reference_sequence"]`` and ``example["variant_sequence"]``
    are processed. A sequence is trimmed only when it is longer than
    ``2 * truncate_dna_per_side + 8``, i.e. only when more than 8 elements
    would remain afterwards. Sequences that are too short are left unchanged.

    Args:
        example: Record holding the ``reference_sequence`` and
            ``variant_sequence`` keys. It is modified in place.
        truncate_dna_per_side: Number of elements to remove from each end.
            ``0`` leaves every sequence unchanged.

    Returns:
        The same ``example`` object, with its sequences trimmed.

    Raises:
        KeyError: If either sequence key is missing from ``example``.
        ValueError: If ``truncate_dna_per_side`` is negative.
    """
    if truncate_dna_per_side < 0:
        raise ValueError(
            f"truncate_dna_per_side must be non-negative, got {truncate_dna_per_side}"
        )
    for key in ("reference_sequence", "variant_sequence"):
        sequence = example[key]
        seq_len = len(sequence)
        # Only trim when more than 8 elements would survive the cut.
        if seq_len > 2 * truncate_dna_per_side + 8:
            # Use an explicit end index: ``sequence[n:-n]`` is empty when n == 0
            # because -0 == 0.
            example[key] = sequence[
                truncate_dna_per_side : seq_len - truncate_dna_per_side
            ]
    return example
def torch_to_hf_dataset(torch_dataset: TorchDataset) -> HFDataset:
    """
    Materialize a map-style PyTorch Dataset into a Hugging Face Dataset.

    Every item is read by integer index and collected column-wise.
    - If item 0 is a dict, its keys become the column names, and each later
      dict item contributes its value for each of those keys.
    - If item 0 is not a dict, a single ``"data"`` column holds the raw items.

    Args:
        torch_dataset: Indexable PyTorch Dataset supporting ``len()``.

    Returns:
        A Hugging Face Dataset built from the collected columns. An empty
        Dataset is returned when ``torch_dataset`` has no items.
    """
    total = len(torch_dataset)
    if total == 0:
        return HFDataset.from_dict({})

    # Item 0 only decides the column layout. It is fetched again inside the
    # loop below, just like every other item.
    sample = torch_dataset[0]
    if isinstance(sample, dict):
        columns: Dict[str, List[Any]] = {name: [] for name in sample.keys()}
    else:
        columns = {"data": []}

    # Index explicitly: map-style datasets are not guaranteed to be iterable.
    for idx in range(total):
        record = torch_dataset[idx]
        if not isinstance(record, dict):
            columns["data"].append(record)
            continue
        for name, values in columns.items():
            values.append(record[name])

    return HFDataset.from_dict(columns)