# crispr-bert-api / data_loader.py
# santu0032's picture
# Upload 9 files
# f6e15bf verified
# Raw
# History Blame Contribute Delete
# 3.03 kB
"""
Data Loader for CRISPR Datasets
Loads sequences from txt files and encodes them for CNN/BERT models
"""
import numpy as np
from sequence_encoder import encode_batch_for_cnn, encode_batch_for_bert
def load_dataset(file_path, max_samples=None):
    """
    Load sgRNA-DNA pairs from a comma-separated dataset file.

    Each data line is expected to look like ``sgRNA,DNA,label`` and may carry
    additional trailing columns, which are ignored. Blank lines and lines with
    fewer than three comma-separated fields are skipped.

    Args:
        file_path (str): Path to dataset file (e.g., 'datasets/I1.txt')
        max_samples (int): Maximum number of samples to load (None = all).
            The limit counts samples actually loaded, so blank or malformed
            lines do not consume it. ``0`` loads nothing.

    Returns:
        tuple: (sgrna_list, dna_list, labels) where the first two are lists
        of strings and labels is a NumPy array of floats, all of equal length.

    Raises:
        ValueError: If the third field of a data line is not a valid float.
        OSError: If the file cannot be opened.
    """
    sgrna_list = []
    dna_list = []
    labels = []

    with open(file_path, 'r') as f:
        for line in f:
            # Compare against the number of samples kept, not the raw line
            # index, so skipped lines cannot shrink the result below the cap.
            if max_samples is not None and len(labels) >= max_samples:
                break

            line = line.strip()
            if not line:
                continue

            # Parse line: sgRNA,DNA,label or sgRNA,DNA,label,additional_cols
            parts = line.split(',')
            if len(parts) < 3:
                continue

            # Convert the label first: if it is invalid, nothing has been
            # appended yet and the three lists stay in step.
            label = float(parts[2])
            sgrna_list.append(parts[0])
            dna_list.append(parts[1])
            labels.append(label)

    return sgrna_list, dna_list, np.array(labels)
def load_and_encode_for_cnn(file_path, max_samples=None):
    """
    Read a dataset file and turn its sequence pairs into CNN input.

    Delegates parsing to ``load_dataset`` and encoding to
    ``encode_batch_for_cnn``, printing progress messages along the way.

    Args:
        file_path (str): Path to dataset file
        max_samples (int): Maximum samples to load, forwarded to
            ``load_dataset``

    Returns:
        tuple: (X_cnn, y) -- the encoder output and the label array. The
        encoder is documented elsewhere as producing shape
        (n_samples, 26, 7); that is not verifiable from this function.
    """
    # Parse the raw text file into parallel sequence lists and labels.
    print(f"Loading dataset from {file_path}...")
    guides, targets, y = load_dataset(file_path, max_samples)
    n_loaded = len(guides)
    print(f" Loaded {n_loaded} sequences")

    # Hand the paired sequences to the CNN batch encoder.
    print("Encoding for CNN...")
    encoded = encode_batch_for_cnn(guides, targets)
    print(f" CNN encoded shape: {encoded.shape}")

    return encoded, y
def load_and_encode_for_bert(file_path, max_samples=None):
    """
    Read a dataset file and turn its sequence pairs into BERT input.

    Delegates parsing to ``load_dataset`` and encoding to
    ``encode_batch_for_bert``, printing progress messages along the way.

    Args:
        file_path (str): Path to dataset file
        max_samples (int): Maximum samples to load, forwarded to
            ``load_dataset``

    Returns:
        tuple: (X_bert, y) -- the encoder output and the label array. The
        encoder is documented elsewhere as producing shape (n_samples, 26);
        that is not verifiable from this function.
    """
    # Parse the raw text file into parallel sequence lists and labels.
    print(f"Loading dataset from {file_path}...")
    loaded = load_dataset(file_path, max_samples)
    sgrna_seqs, dna_seqs, y = loaded
    print(f" Loaded {len(sgrna_seqs)} sequences")

    # Hand the paired sequences to the BERT batch encoder.
    print("Encoding for BERT...")
    token_ids = encode_batch_for_bert(sgrna_seqs, dna_seqs)
    print(f" BERT encoded shape: {token_ids.shape}")

    return token_ids, y