|
|
|
|
|
|
|
|
import torch |
|
|
from torch.utils.data import Dataset |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import os |
|
|
import pyfaidx |
|
|
import kipoiseq.transforms.functional |
|
|
from rdkit import Chem |
|
|
from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect |
|
|
from rdkit.Chem import rdFingerprintGenerator |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ENFORMER_INPUT_SEQ_LENGTH = 49_152 |
|
|
|
|
|
|
|
|
GENOME_FASTA_PATH = "data/hg38.fa" |
|
|
TSS_REGIONS_CSV_PATH = "data/Enformer_genomic_regions_TSSCenteredGenes_FixedOverlapRemoval.csv" |
|
|
|
|
|
|
|
|
PSEUDOBULK_TARGET_DATA_PATH = "data/pseudobulk_dummy.csv" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GenomeOneHotEncoder: |
|
|
""" |
|
|
Encodes DNA sequences into one-hot format using kipoiseq. |
|
|
""" |
|
|
def __init__(self, sequence_length: int = ENFORMER_INPUT_SEQ_LENGTH): |
|
|
self.sequence_length = sequence_length |
|
|
|
|
|
@staticmethod |
|
|
def _one_hot_encode(sequence: str) -> np.ndarray: |
|
|
|
|
|
|
|
|
|
|
|
return kipoiseq.transforms.functional.one_hot_dna(sequence).astype(np.float32) |
|
|
|
|
|
def encode(self, seq: str) -> np.ndarray: |
|
|
""" |
|
|
One-hot encodes a DNA sequence using kipoiseq. |
|
|
|
|
|
Args: |
|
|
seq (str): The DNA sequence string. The FastaReader should ensure this |
|
|
sequence is already uppercase and of length ENFORMER_INPUT_SEQ_LENGTH. |
|
|
|
|
|
Returns: |
|
|
np.ndarray: A numpy array, typically (L, 4) for DNA, with one-hot encoded sequence. |
|
|
""" |
|
|
|
|
|
return GenomeOneHotEncoder._one_hot_encode(seq) |
|
|
|
|
|
|
|
|
class FastaReader: |
|
|
""" |
|
|
Reads sequences from a FASTA file using pyfaidx. |
|
|
Handles chromosome boundary conditions by padding with 'N'. |
|
|
""" |
|
|
def __init__(self, fasta_path: str): |
|
|
self.fasta_path = fasta_path |
|
|
self.genome = None |
|
|
try: |
|
|
self.genome = pyfaidx.Fasta(self.fasta_path, sequence_always_upper=True) |
|
|
print(f"Successfully loaded and indexed genome using pyfaidx from: {self.fasta_path}") |
|
|
except pyfaidx.FastaIndexingError as e: |
|
|
print(f"Error: Could not index FASTA file at {self.fasta_path}.") |
|
|
print("Ensure it's a valid FASTA file and the .fai index can be created/read in its directory.") |
|
|
print(f"pyfaidx error: {e}") |
|
|
raise |
|
|
except FileNotFoundError: |
|
|
print(f"Error: FASTA file not found at {self.fasta_path}.") |
|
|
raise |
|
|
|
|
|
def get_sequence(self, chrom: str, start_0based: int, end_0based_exclusive: int) -> str: |
|
|
""" |
|
|
Fetches a DNA sequence for the given 0-based genomic interval. |
|
|
Pads with 'N' if the interval extends beyond chromosome boundaries. |
|
|
|
|
|
Args: |
|
|
chrom (str): Chromosome name (e.g., 'chr1'). |
|
|
start_0based (int): 0-based start coordinate (inclusive). |
|
|
end_0based_exclusive (int): 0-based end coordinate (exclusive). |
|
|
|
|
|
Returns: |
|
|
str: The DNA sequence, padded with 'N's to match the requested length |
|
|
(end_0based_exclusive - start_0based). |
|
|
""" |
|
|
if self.genome is None: |
|
|
raise RuntimeError("FastaReader not properly initialized (pyfaidx missing or genome loading failed).") |
|
|
|
|
|
|
|
|
true_chrom_name = chrom |
|
|
if chrom not in self.genome: |
|
|
alternative_chrom_name = 'chr' + chrom if not chrom.startswith('chr') else chrom.replace('chr', '', 1) |
|
|
if alternative_chrom_name in self.genome: |
|
|
true_chrom_name = alternative_chrom_name |
|
|
else: |
|
|
available_chroms_sample = list(self.genome.keys())[:5] |
|
|
raise ValueError( |
|
|
f"Chromosome '{chrom}' (and alternative '{alternative_chrom_name}') not found in FASTA file. " |
|
|
f"Available chromosomes sample: {available_chroms_sample}..." |
|
|
) |
|
|
|
|
|
chrom_len = len(self.genome[true_chrom_name]) |
|
|
seq_len_requested = end_0based_exclusive - start_0based |
|
|
|
|
|
|
|
|
sequence_parts = [] |
|
|
|
|
|
|
|
|
padding_start_len = 0 |
|
|
if start_0based < 0: |
|
|
padding_start_len = abs(start_0based) |
|
|
sequence_parts.append('N' * padding_start_len) |
|
|
effective_start = 0 |
|
|
else: |
|
|
effective_start = start_0based |
|
|
|
|
|
|
|
|
fetch_len = min(end_0based_exclusive, chrom_len) - effective_start |
|
|
|
|
|
if fetch_len > 0: |
|
|
sequence_parts.append(self.genome[true_chrom_name][effective_start : effective_start + fetch_len].seq) |
|
|
elif effective_start >= chrom_len: |
|
|
pass |
|
|
|
|
|
|
|
|
current_len = sum(len(p) for p in sequence_parts) |
|
|
padding_end_len = seq_len_requested - current_len |
|
|
if padding_end_len > 0: |
|
|
sequence_parts.append('N' * padding_end_len) |
|
|
|
|
|
final_sequence = "".join(sequence_parts) |
|
|
|
|
|
|
|
|
if len(final_sequence) != seq_len_requested: |
|
|
|
|
|
raise RuntimeError( |
|
|
f"Internal error: Final sequence length {len(final_sequence)} for {true_chrom_name}:{start_0based}-{end_0based_exclusive} " |
|
|
f"does not match requested {seq_len_requested}." |
|
|
) |
|
|
return final_sequence |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TahoeDataset(Dataset): |
|
|
""" |
|
|
PyTorch Dataset for loading Tahoe data for Enformer fine-tuning. |
|
|
- Reads genomic regions from a regions CSV. |
|
|
- Reads pseudobulk conditions and expression values from a pseudobulk CSV. |
|
|
- Merges these two data sources based on gene identifiers. |
|
|
- Each sample represents a unique gene-condition pair. |
|
|
- Fetches DNA sequence for the gene, resized to `enformer_input_seq_length`. |
|
|
- One-hot encodes sequence and returns it with the specific expression value. |
|
|
""" |
|
|
ORIGINAL_ENFORMER_WINDOW_SIZE = 196_608 |
|
|
|
|
|
def __init__(self, |
|
|
tss_regions_csv_path: str, |
|
|
genome_fasta_path: str, |
|
|
pseudobulk_data_path: str, |
|
|
enformer_input_seq_length: int = ENFORMER_INPUT_SEQ_LENGTH, |
|
|
regions_csv_gene_col: str = 'gene_name', |
|
|
pseudobulk_csv_gene_col: str = 'gene_id', |
|
|
regions_csv_chr_col: str = 'seqnames', |
|
|
regions_csv_start_col: str = 'starts', |
|
|
regions_csv_end_col: str = 'ends'): |
|
|
super().__init__() |
|
|
|
|
|
self.enformer_input_seq_length = enformer_input_seq_length |
|
|
|
|
|
self.regions_gene_col = regions_csv_gene_col |
|
|
self.pseudobulk_gene_col = pseudobulk_csv_gene_col |
|
|
self.regions_chr_col = regions_csv_chr_col |
|
|
self.regions_start_col = regions_csv_start_col |
|
|
self.regions_end_col = regions_csv_end_col |
|
|
|
|
|
print(f"Initializing TahoeDataset...") |
|
|
print(f" Target model input sequence length: {self.enformer_input_seq_length} bp") |
|
|
print(f" Genomic regions are assumed to define a {self.ORIGINAL_ENFORMER_WINDOW_SIZE} bp window for centering.") |
|
|
|
|
|
|
|
|
print(f" Loading TSS regions from: {tss_regions_csv_path}") |
|
|
try: |
|
|
regions_df = pd.read_csv(tss_regions_csv_path) |
|
|
print(f" Successfully loaded regions CSV with {len(regions_df)} gene region entries.") |
|
|
expected_region_cols = [self.regions_chr_col, self.regions_gene_col, |
|
|
self.regions_start_col, self.regions_end_col] |
|
|
missing_region_cols = [col for col in expected_region_cols if col not in regions_df.columns] |
|
|
if missing_region_cols: |
|
|
raise ValueError(f"Missing columns in regions CSV ('{tss_regions_csv_path}'): {missing_region_cols}. Expected: {expected_region_cols}") |
|
|
except FileNotFoundError: |
|
|
print(f"FATAL ERROR: Regions CSV file not found at {tss_regions_csv_path}") |
|
|
raise |
|
|
except Exception as e: |
|
|
print(f"FATAL ERROR loading or validating regions CSV: {e}") |
|
|
raise |
|
|
|
|
|
|
|
|
print(f" Loading pseudobulk targets from: {pseudobulk_data_path}") |
|
|
try: |
|
|
pseudobulk_df = pd.read_csv(pseudobulk_data_path) |
|
|
print(f" Successfully loaded pseudobulk CSV with {len(pseudobulk_df)} condition entries.") |
|
|
expected_pb_cols = [self.pseudobulk_gene_col, 'cell_line', 'drug_id', 'drug_dose', 'expression'] |
|
|
missing_pb_cols = [col for col in expected_pb_cols if col not in pseudobulk_df.columns] |
|
|
if missing_pb_cols: |
|
|
raise ValueError(f"Missing columns in pseudobulk CSV ('{pseudobulk_data_path}'): {missing_pb_cols}. Expected: {expected_pb_cols}") |
|
|
except FileNotFoundError: |
|
|
print(f"FATAL ERROR: Pseudobulk CSV file not found at {pseudobulk_data_path}") |
|
|
raise |
|
|
except Exception as e: |
|
|
print(f"FATAL ERROR loading or validating pseudobulk CSV: {e}") |
|
|
raise |
|
|
|
|
|
|
|
|
print(f" Merging genomic regions with pseudobulk target data...") |
|
|
print(f" Regions gene column: '{self.regions_gene_col}', Pseudobulk gene column: '{self.pseudobulk_gene_col}'") |
|
|
|
|
|
regions_df[self.regions_gene_col] = regions_df[self.regions_gene_col].astype(str) |
|
|
pseudobulk_df[self.pseudobulk_gene_col] = pseudobulk_df[self.pseudobulk_gene_col].astype(str) |
|
|
|
|
|
self.samples_df = pd.merge( |
|
|
regions_df, |
|
|
pseudobulk_df, |
|
|
left_on=self.regions_gene_col, |
|
|
right_on=self.pseudobulk_gene_col, |
|
|
how='inner' |
|
|
) |
|
|
|
|
|
if len(self.samples_df) == 0: |
|
|
print("WARNING: The merge operation resulted in an empty DataFrame.") |
|
|
print(f" No common genes found between column '{self.regions_gene_col}' in regions CSV ") |
|
|
print(f" and column '{self.pseudobulk_gene_col}' in pseudobulk CSV.") |
|
|
print(f" Please check that gene identifiers match and are of the same type in both files.") |
|
|
|
|
|
if not regions_df.empty: print(f" Sample gene IDs from regions CSV: {regions_df[self.regions_gene_col].unique()[:5].tolist()}") |
|
|
if not pseudobulk_df.empty: print(f" Sample gene IDs from pseudobulk CSV: {pseudobulk_df[self.pseudobulk_gene_col].unique()[:5].tolist()}") |
|
|
else: |
|
|
print(f" Successfully merged data: {len(self.samples_df)} total samples (gene-condition pairs).") |
|
|
|
|
|
|
|
|
original_region_genes = set(regions_df[self.regions_gene_col].unique()) |
|
|
merged_region_genes = set(self.samples_df[self.regions_gene_col].unique()) |
|
|
dropped_region_genes = original_region_genes - merged_region_genes |
|
|
if dropped_region_genes: |
|
|
print(f" WARNING: {len(dropped_region_genes)} unique gene IDs from the regions CSV ('{self.regions_gene_col}') were not found in the pseudobulk CSV ('{self.pseudobulk_gene_col}') and were dropped.") |
|
|
print(f" Examples of dropped region gene IDs: {list(dropped_region_genes)[:min(5, len(dropped_region_genes))]}") |
|
|
|
|
|
|
|
|
original_pseudobulk_genes = set(pseudobulk_df[self.pseudobulk_gene_col].unique()) |
|
|
|
|
|
merged_pseudobulk_genes = set(self.samples_df[self.regions_gene_col].unique()) |
|
|
|
|
|
final_merged_keys_from_pseudobulk_perspective = set(self.samples_df[self.pseudobulk_gene_col].unique()) |
|
|
dropped_pseudobulk_genes = original_pseudobulk_genes - final_merged_keys_from_pseudobulk_perspective |
|
|
|
|
|
if dropped_pseudobulk_genes: |
|
|
print(f" WARNING: {len(dropped_pseudobulk_genes)} unique gene IDs from the pseudobulk CSV ('{self.pseudobulk_gene_col}') were not found in the regions CSV ('{self.regions_gene_col}') and were dropped.") |
|
|
print(f" Examples of dropped pseudobulk gene IDs: {list(dropped_pseudobulk_genes)[:min(5, len(dropped_pseudobulk_genes))]}") |
|
|
|
|
|
if 'expression' in self.samples_df and self.samples_df['expression'].isnull().any(): |
|
|
print("WARNING: NA values found in 'expression' column after merge. These samples might cause errors or yield NaN targets.") |
|
|
print(" Consider handling these (e.g., fill with a default or drop rows withna(subset=['expression'])).") |
|
|
|
|
|
|
|
|
print(f" Initializing FASTA reader for genome: {genome_fasta_path}") |
|
|
self.fasta_reader = FastaReader(genome_fasta_path) |
|
|
self.encoder = GenomeOneHotEncoder(sequence_length=self.enformer_input_seq_length) |
|
|
print("TahoeDataset initialized successfully.") |
|
|
|
|
|
def __len__(self): |
|
|
return len(self.samples_df) |
|
|
|
|
|
def __getitem__(self, idx: int): |
|
|
if torch.is_tensor(idx): |
|
|
idx = idx.tolist() |
|
|
|
|
|
if not (0 <= idx < len(self.samples_df)): |
|
|
raise IndexError(f"Index {idx} out of bounds for dataset of length {len(self.samples_df)}") |
|
|
|
|
|
sample_info = self.samples_df.iloc[idx] |
|
|
|
|
|
try: |
|
|
chrom = str(sample_info[self.regions_chr_col]) |
|
|
|
|
|
gene_name_for_logging = str(sample_info[self.regions_gene_col]) |
|
|
|
|
|
csv_region_start = int(sample_info[self.regions_start_col]) |
|
|
csv_region_end = int(sample_info[self.regions_end_col]) |
|
|
|
|
|
expression_value = float(sample_info['expression']) |
|
|
except KeyError as e: |
|
|
print(f"FATAL ERROR in __getitem__ (idx {idx}): Missing expected column {e} in merged samples_df.") |
|
|
print(f" Available columns: {self.samples_df.columns.tolist()}") |
|
|
print(f" Sample info for this index: {sample_info.to_dict() if isinstance(sample_info, pd.Series) else sample_info}") |
|
|
raise |
|
|
except ValueError as e: |
|
|
print(f"FATAL ERROR in __getitem__ (idx {idx}): Could not convert data for gene {gene_name_for_logging}. Error: {e}") |
|
|
print(f" Expression value was: '{sample_info.get('expression', 'N/A')}'") |
|
|
raise |
|
|
except Exception as e: |
|
|
print(f"FATAL ERROR in __getitem__ (idx {idx}) for gene {gene_name_for_logging}: An unexpected error occurred: {type(e).__name__} - {e}") |
|
|
raise |
|
|
|
|
|
|
|
|
actual_csv_window_len = csv_region_end - csv_region_start |
|
|
if actual_csv_window_len != self.ORIGINAL_ENFORMER_WINDOW_SIZE: |
|
|
|
|
|
|
|
|
print(f"WARNING for gene {gene_name_for_logging} (idx {idx}): Region {chrom}:{csv_region_start}-{csv_region_end} from CSV " |
|
|
f"has length {actual_csv_window_len}bp, but expected {self.ORIGINAL_ENFORMER_WINDOW_SIZE}bp " |
|
|
f"for the original window definition used for centering. Sequence extraction might be affected if assumptions are wrong.") |
|
|
|
|
|
|
|
|
|
|
|
final_seq_start_0based = csv_region_start |
|
|
final_seq_end_0based_exclusive = csv_region_end |
|
|
|
|
|
|
|
|
|
|
|
if self.enformer_input_seq_length != self.ORIGINAL_ENFORMER_WINDOW_SIZE: |
|
|
|
|
|
|
|
|
original_window_center = csv_region_end - (self.ORIGINAL_ENFORMER_WINDOW_SIZE // 2) |
|
|
|
|
|
half_target_seq_len = self.enformer_input_seq_length // 2 |
|
|
final_seq_start_0based = original_window_center - half_target_seq_len |
|
|
|
|
|
final_seq_end_0based_exclusive = final_seq_start_0based + self.enformer_input_seq_length |
|
|
|
|
|
|
|
|
dna_sequence = self.fasta_reader.get_sequence(chrom, final_seq_start_0based, final_seq_end_0based_exclusive) |
|
|
one_hot_sequence = self.encoder.encode(dna_sequence) |
|
|
one_hot_sequence_tensor = torch.tensor(one_hot_sequence, dtype=torch.float32) |
|
|
|
|
|
|
|
|
target_tensor = torch.tensor([expression_value], dtype=torch.float32) |
|
|
|
|
|
return one_hot_sequence_tensor, target_tensor |
|
|
|
|
|
|
|
|
|
|
|
class TahoeSMILESDataset(Dataset): |
|
|
""" |
|
|
Extends TahoeDataset to also return: |
|
|
- Morgan Fingerprints for the drug |
|
|
- drug dose |
|
|
- target expression |
|
|
""" |
|
|
def __init__(self, |
|
|
regions_csv_path: str, |
|
|
pbulk_parquet_path: str, |
|
|
drug_meta_csv_path: str, |
|
|
fasta_file_path: str, |
|
|
enformer_input_seq_length: int = ENFORMER_INPUT_SEQ_LENGTH, |
|
|
|
|
|
morgan_fp_radius: int = 2, |
|
|
morgan_fp_nbits: int = 2048, |
|
|
|
|
|
regions_gene_col: str = 'gene_name', |
|
|
regions_chr_col: str = 'seqnames', |
|
|
regions_start_col: str = 'starts', |
|
|
regions_end_col: str = 'ends', |
|
|
|
|
|
pbulk_gene_col: str = 'gene_id', |
|
|
pbulk_drug_col: str = 'drug_id', |
|
|
pbulk_dose_col: str = 'drug_dose', |
|
|
pbulk_expr_col: str = 'expression', |
|
|
pbulk_cell_line_col: str = 'cell_line', |
|
|
|
|
|
drug_meta_id_col: str = 'drug', |
|
|
drug_meta_smiles_col: str = 'canonical_smiles', |
|
|
filter_drugs_by_ids: list = None, |
|
|
regions_strand_col: str = None, |
|
|
regions_set_col: str = 'set', |
|
|
target_set: str = None |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
self.seq_len = enformer_input_seq_length |
|
|
self.morgan_fp_radius = morgan_fp_radius |
|
|
self.morgan_fp_nbits = morgan_fp_nbits |
|
|
|
|
|
self.regions_gene_col = regions_gene_col |
|
|
self.regions_chr_col = regions_chr_col |
|
|
self.regions_start_col = regions_start_col |
|
|
self.regions_end_col = regions_end_col |
|
|
self.regions_set_col = regions_set_col |
|
|
|
|
|
self.pbulk_gene_col = pbulk_gene_col |
|
|
self.pbulk_drug_col = pbulk_drug_col |
|
|
self.pbulk_dose_col = pbulk_dose_col |
|
|
self.pbulk_expr_col = pbulk_expr_col |
|
|
self.pbulk_cell_line_col = pbulk_cell_line_col |
|
|
|
|
|
self.drug_meta_id_col = drug_meta_id_col |
|
|
self.drug_meta_smiles_col= drug_meta_smiles_col |
|
|
|
|
|
self.target_set = target_set |
|
|
|
|
|
|
|
|
self._morgan_gen = rdFingerprintGenerator.GetMorganGenerator( |
|
|
radius=self.morgan_fp_radius, |
|
|
fpSize=self.morgan_fp_nbits |
|
|
) |
|
|
|
|
|
|
|
|
print(f" Loading TSS regions from: {regions_csv_path}") |
|
|
try: |
|
|
regs = pd.read_csv(regions_csv_path) |
|
|
print(f" Successfully loaded regions CSV with {len(regs)} gene region entries.") |
|
|
except FileNotFoundError: |
|
|
print(f"FATAL ERROR: Regions CSV file not found at {regions_csv_path}") |
|
|
raise |
|
|
except Exception as e: |
|
|
print(f"FATAL ERROR loading regions CSV: {e}") |
|
|
raise |
|
|
|
|
|
print(f" Loading pseudobulk targets from: {pbulk_parquet_path} (expected Parquet format)") |
|
|
try: |
|
|
pb = pd.read_parquet(pbulk_parquet_path) |
|
|
print(f" Successfully loaded pseudobulk Parquet file with {len(pb)} entries.") |
|
|
except FileNotFoundError: |
|
|
print(f"FATAL ERROR: Pseudobulk Parquet file not found at {pbulk_parquet_path}") |
|
|
raise |
|
|
except Exception as e: |
|
|
print(f"FATAL ERROR loading or parsing pseudobulk Parquet file: {e}") |
|
|
print(" Ensure the file is a valid Parquet file and you have a Parquet engine like 'pyarrow' or 'fastparquet' installed.") |
|
|
raise |
|
|
|
|
|
|
|
|
regs[self.regions_gene_col] = regs[self.regions_gene_col].astype(str) |
|
|
pb[self.pbulk_gene_col] = pb[self.pbulk_gene_col].astype(str) |
|
|
|
|
|
print(f" Merging genomic regions with pseudobulk target data...") |
|
|
print(f" Regions gene column: '{self.regions_gene_col}', Pseudobulk gene column: '{self.pbulk_gene_col}'") |
|
|
self.samples_df = regs.merge( |
|
|
pb, |
|
|
left_on = self.regions_gene_col, |
|
|
right_on = self.pbulk_gene_col, |
|
|
how = 'inner' |
|
|
) |
|
|
|
|
|
if filter_drugs_by_ids and self.pbulk_drug_col in self.samples_df.columns: |
|
|
print(f" Filtering samples to include only drugs: {filter_drugs_by_ids}") |
|
|
initial_count = len(self.samples_df) |
|
|
self.samples_df = self.samples_df[self.samples_df[self.pbulk_drug_col].isin(filter_drugs_by_ids)] |
|
|
print(f" Retained {len(self.samples_df)} samples after drug filtering (from {initial_count}).") |
|
|
if len(self.samples_df) == 0: |
|
|
print("WARNING: No samples remaining after filtering by drug IDs. Check your filter_drugs_by_ids list and drug IDs in pbulk data.") |
|
|
|
|
|
|
|
|
if self.target_set: |
|
|
if self.regions_set_col in self.samples_df.columns: |
|
|
print(f" Filtering samples for set: '{self.target_set}' using column '{self.regions_set_col}'.") |
|
|
initial_count_set_filter = len(self.samples_df) |
|
|
self.samples_df = self.samples_df[self.samples_df[self.regions_set_col] == self.target_set].copy() |
|
|
print(f" Retained {len(self.samples_df)} samples after filtering for set '{self.target_set}' (from {initial_count_set_filter}).") |
|
|
if len(self.samples_df) == 0: |
|
|
print(f"WARNING: No samples remaining for this dataset instance (target_set='{self.target_set}') after filtering. Check the '{self.regions_set_col}' column in '{regions_csv_path}' for entries matching '{self.target_set}' and their overlap with pseudobulk data.") |
|
|
else: |
|
|
print(f"WARNING: target_set '{self.target_set}' was specified, but the column '{self.regions_set_col}' was not found in the merged DataFrame. No set-specific filtering was applied for this dataset instance. This instance will contain all data that matched other criteria.") |
|
|
|
|
|
|
|
|
print(f" Loading drug metadata from: {drug_meta_csv_path}") |
|
|
try: |
|
|
dm = pd.read_csv(drug_meta_csv_path) |
|
|
print(f" Successfully loaded drug metadata with {len(dm)} entries.") |
|
|
except FileNotFoundError: |
|
|
print(f"FATAL ERROR: Drug metadata CSV not found at {drug_meta_csv_path}") |
|
|
raise |
|
|
except Exception as e: |
|
|
print(f"FATAL ERROR loading drug metadata CSV: {e}") |
|
|
raise |
|
|
|
|
|
|
|
|
if self.drug_meta_smiles_col not in dm.columns: |
|
|
raise ValueError(f"SMILES column '{self.drug_meta_smiles_col}' not found in drug metadata.") |
|
|
if self.drug_meta_id_col not in dm.columns: |
|
|
raise ValueError(f"Drug ID column '{self.drug_meta_id_col}' not found in drug metadata.") |
|
|
dm[self.drug_meta_smiles_col] = dm[self.drug_meta_smiles_col].fillna('').astype(str) |
|
|
self.drug_meta = dm.set_index(self.drug_meta_id_col) |
|
|
|
|
|
|
|
|
self.fasta_reader = FastaReader(fasta_file_path) |
|
|
self.encoder = GenomeOneHotEncoder(sequence_length=self.seq_len) |
|
|
print("TahoeSMILESDataset initialized.") |
|
|
|
|
|
def _generate_morgan_fingerprint(self, smiles_string: str) -> np.ndarray: |
|
|
"""Generates a Morgan fingerprint from a SMILES string using the new generator API.""" |
|
|
if not smiles_string: |
|
|
return np.zeros(self.morgan_fp_nbits, dtype=np.float32) |
|
|
try: |
|
|
mol = Chem.MolFromSmiles(smiles_string) |
|
|
if mol: |
|
|
|
|
|
fp_array = self._morgan_gen.GetFingerprintAsNumPy(mol) |
|
|
return fp_array.astype(np.float32) |
|
|
else: |
|
|
return np.zeros(self.morgan_fp_nbits, dtype=np.float32) |
|
|
except Exception as e: |
|
|
return np.zeros(self.morgan_fp_nbits, dtype=np.float32) |
|
|
|
|
|
def __len__(self): |
|
|
return len(self.samples_df) |
|
|
|
|
|
def __getitem__(self, idx): |
|
|
row = self.samples_df.iloc[idx] |
|
|
|
|
|
|
|
|
chrom = str(row[self.regions_chr_col]) |
|
|
start = int(row[self.regions_start_col]) |
|
|
end = int(row[self.regions_end_col]) |
|
|
orig = end - start |
|
|
if self.seq_len != orig: |
|
|
center = end - orig//2 |
|
|
half = self.seq_len//2 |
|
|
start, end = center-half, center+half |
|
|
|
|
|
seq = self.fasta_reader.get_sequence(chrom, start, end) |
|
|
oh = self.encoder.encode(seq) |
|
|
seq_tensor = torch.tensor(oh, dtype=torch.float32) |
|
|
|
|
|
|
|
|
drug_id_for_fp = row[self.pbulk_drug_col] |
|
|
smiles_string = '' |
|
|
if drug_id_for_fp in self.drug_meta.index: |
|
|
smiles_string = self.drug_meta.loc[drug_id_for_fp, self.drug_meta_smiles_col] |
|
|
|
|
|
if isinstance(smiles_string, pd.Series): |
|
|
smiles_string = smiles_string.iloc[0] |
|
|
else: |
|
|
|
|
|
pass |
|
|
|
|
|
morgan_fp = self._generate_morgan_fingerprint(str(smiles_string)) |
|
|
morgan_fp_tensor = torch.tensor(morgan_fp, dtype=torch.float32) |
|
|
|
|
|
|
|
|
dose_val = float(row[self.pbulk_dose_col]) |
|
|
expression_val = float(row[self.pbulk_expr_col]) |
|
|
|
|
|
dose_tensor = torch.tensor([dose_val], dtype=torch.float32) |
|
|
tgt_tensor = torch.tensor([expression_val], dtype=torch.float32) |
|
|
|
|
|
|
|
|
gene_id_meta = str(row[self.pbulk_gene_col]) |
|
|
drug_id_meta = str(row[self.pbulk_drug_col]) |
|
|
cell_line_meta = str(row[self.pbulk_cell_line_col]) |
|
|
|
|
|
return seq_tensor, morgan_fp_tensor, dose_tensor, tgt_tensor, gene_id_meta, drug_id_meta, cell_line_meta, chrom, start, end |