Spaces:
Sleeping
Sleeping
| import requests | |
| import tensorflow as tf | |
| import pandas as pd | |
| import numpy as np | |
| from operator import add | |
| from functools import reduce | |
| import random | |
| import tabulate | |
| from keras import Model | |
| from keras import regularizers | |
| from keras.optimizers import Adam | |
| from keras.layers import Conv2D, BatchNormalization, ReLU, Input, Flatten, Softmax | |
| from keras.layers import Concatenate, Activation, Dense, GlobalAveragePooling2D, Dropout | |
| from keras.layers import AveragePooling1D, Bidirectional, LSTM, GlobalAveragePooling1D, MaxPool1D, Reshape | |
| from keras.layers import LayerNormalization, Conv1D, MultiHeadAttention, Layer | |
| from keras.models import load_model | |
| from keras.callbacks import EarlyStopping, ReduceLROnPlateau | |
| from Bio import SeqIO | |
| from Bio.SeqRecord import SeqRecord | |
| from Bio.SeqFeature import SeqFeature, FeatureLocation | |
| from Bio.Seq import Seq | |
| import cyvcf2 | |
| import parasail | |
| import re | |
# One-hot code for each nucleotide; tuple positions are the A, C, G, T channels.
ntmap = {
    'A': (1, 0, 0, 0),
    'C': (0, 1, 0, 0),
    'G': (0, 0, 1, 0),
    'T': (0, 0, 0, 1),
}


def get_seqcode(seq):
    """One-hot encode a nucleotide string.

    Args:
        seq: string made of A/C/G/T characters (case-insensitive).

    Returns:
        Integer ``numpy`` array of shape ``(1, len(seq), 4)``; an empty
        string yields an array of shape ``(1, 0, 4)``.

    Raises:
        KeyError: if ``seq`` contains a character that is not in ``ntmap``.
    """
    # Build the rows directly instead of concatenating tuples pairwise, which
    # is quadratic in the sequence length and fails on an empty sequence.
    rows = [ntmap[base] for base in seq.upper()]
    return np.array(rows, dtype=int).reshape((1, len(seq), len(ntmap)))
class PositionalEncoding(Layer):
    """Keras layer that adds a fixed sinusoidal position table to its input.

    The table has shape ``(sequence_len, embedding_dim)`` and is added to the
    input by broadcasting, so the input's trailing two dimensions are expected
    to match it.
    """

    def __init__(self, sequence_len=None, embedding_dim=None, **kwargs):
        # Forward the standard Layer arguments (name, dtype, trainable, ...).
        # get_config() emits them, so dropping them here would silently lose
        # them whenever the layer is rebuilt from its config.
        super().__init__(**kwargs)
        self.sequence_len = sequence_len
        self.embedding_dim = embedding_dim

    def call(self, x):
        """Return ``x`` plus the position table."""
        # Entry [pos, i] starts as pos / 10000 ** (2 * i / embedding_dim).
        # NOTE(review): the exponent uses the raw column index i rather than
        # the i // 2 of the usual formulation; left unchanged on purpose so
        # the layer keeps producing the same values as before.
        position_embedding = np.array([
            [pos / np.power(10000, 2. * i / self.embedding_dim) for i in range(self.embedding_dim)]
            for pos in range(self.sequence_len)])
        position_embedding[:, 0::2] = np.sin(position_embedding[:, 0::2])  # even columns
        position_embedding[:, 1::2] = np.cos(position_embedding[:, 1::2])  # odd columns
        position_embedding = tf.cast(position_embedding, dtype=tf.float32)
        return position_embedding + x

    def get_config(self):
        """Return the base layer config extended with this layer's arguments."""
        config = super().get_config().copy()
        config.update({
            'sequence_len': self.sequence_len,
            'embedding_dim': self.embedding_dim,
        })
        return config
def MultiHeadAttention_model(input_shape):
    """Build the Conv1D -> BiLSTM -> multi-head-attention regression model.

    Args:
        input_shape: per-sample shape ``(sequence_length, channels)``; the
            caller in this file passes ``(23, 4)``.

    Returns:
        An uncompiled ``keras.Model`` with one input and one linear scalar
        output.
    """
    kernel_size = 3
    pool_size = 2
    lstm_units = 128

    # Named `inputs` rather than `input` to avoid shadowing the builtin.
    inputs = Input(shape=input_shape)
    conv1 = Conv1D(256, kernel_size, activation="relu")(inputs)
    pool1 = AveragePooling1D(pool_size)(conv1)
    drop1 = Dropout(0.4)(pool1)
    conv2 = Conv1D(256, kernel_size, activation="relu")(drop1)
    pool2 = AveragePooling1D(pool_size)(conv2)
    drop2 = Dropout(0.4)(pool2)
    lstm = Bidirectional(LSTM(lstm_units,
                              dropout=0.5,
                              activation='tanh',
                              return_sequences=True,
                              kernel_regularizer=regularizers.l2(0.01)))(drop2)

    # Sequence length left after the two conv + pool stages. Derived from
    # input_shape instead of a hard-coded 23; for a length of 23 this is 4,
    # the same value the previous literal expression produced.
    reduced_len = input_shape[0]
    for _ in range(2):
        reduced_len = (reduced_len - kernel_size + 1) // pool_size

    pos_embedding = PositionalEncoding(sequence_len=reduced_len,
                                       embedding_dim=2 * lstm_units)(lstm)
    atten = MultiHeadAttention(num_heads=2,
                               key_dim=64,
                               dropout=0.2,
                               kernel_regularizer=regularizers.l2(0.01))(pos_embedding, pos_embedding)
    flat = Flatten()(atten)
    dense1 = Dense(512,
                   kernel_regularizer=regularizers.l2(1e-4),
                   bias_regularizer=regularizers.l2(1e-4),
                   activation="relu")(flat)
    drop3 = Dropout(0.1)(dense1)
    dense2 = Dense(128,
                   kernel_regularizer=regularizers.l2(1e-4),
                   bias_regularizer=regularizers.l2(1e-4),
                   activation="relu")(drop3)
    drop4 = Dropout(0.1)(dense2)
    dense3 = Dense(256,
                   kernel_regularizer=regularizers.l2(1e-4),
                   bias_regularizer=regularizers.l2(1e-4),
                   activation="relu")(drop4)
    drop5 = Dropout(0.1)(dense3)
    output = Dense(1, activation="linear")(drop5)
    model = Model(inputs=[inputs], outputs=[output])
    return model
def fetch_ensembl_transcripts(gene_symbol):
    """Look up a human gene symbol on the Ensembl REST API and return its transcripts.

    Args:
        gene_symbol: gene symbol inserted into the lookup URL.

    Returns:
        The value of the ``'Transcript'`` key of the JSON response, or
        ``None`` (after printing a message) when the request fails, the
        status is not 200, or the key is absent.
    """
    url = f"https://rest.ensembl.org/lookup/symbol/homo_sapiens/{gene_symbol}?expand=1;content-type=application/json"
    try:
        # Bound the wait so an unresponsive server cannot hang the caller.
        response = requests.get(url, timeout=30)
    except requests.RequestException as exc:
        # Report transport failures the same way as HTTP errors below.
        print(f"Error fetching gene data from Ensembl: {exc}")
        return None

    if response.status_code != 200:
        print(f"Error fetching gene data from Ensembl: {response.text}")
        return None

    gene_data = response.json()
    if 'Transcript' not in gene_data:
        print("No transcripts found for gene:", gene_symbol)
        return None
    return gene_data['Transcript']
def fetch_ensembl_sequence(transcript_id):
    """Fetch the sequence for an Ensembl stable ID from the REST API.

    Args:
        transcript_id: Ensembl identifier inserted into the sequence URL
            (``process_gene`` in this file passes exon IDs).

    Returns:
        The value of the ``'seq'`` key of the JSON response, or ``None``
        (after printing a message) when the request fails, the status is not
        200, or the key is absent.
    """
    url = f"https://rest.ensembl.org/sequence/id/{transcript_id}?content-type=application/json"
    try:
        # Bound the wait so an unresponsive server cannot hang the caller.
        response = requests.get(url, timeout=30)
    except requests.RequestException as exc:
        # Report transport failures the same way as HTTP errors below.
        print(f"Error fetching sequence data from Ensembl: {exc}")
        return None

    if response.status_code != 200:
        print(f"Error fetching sequence data from Ensembl: {response.text}")
        return None

    sequence_data = response.json()
    if 'seq' not in sequence_data:
        print("No sequence found for transcript:", transcript_id)
        return None
    return sequence_data['seq']
def find_crispr_targets(sequence, chr, start, end, strand, transcript_id, exon_id, pam="NGG", target_length=20):
    """Scan ``sequence`` for protospacer + PAM sites.

    A site is reported at index ``i`` when the bases following ``sequence[i]``
    equal ``pam[1:]`` (the first PAM position is treated as a wildcard) and at
    least ``target_length`` bases precede ``i``. Only the given sequence is
    scanned; its reverse complement is not.

    Args:
        sequence: DNA string to scan.
        chr: chromosome / region name, copied into every result row.
        start: coordinate of ``sequence[0]``, used when ``strand`` is not -1.
        end: coordinate used as the origin when ``strand`` is -1.
        strand: ``-1`` for minus-strand coordinate arithmetic, anything else
            for plus-strand arithmetic.
        transcript_id: copied into every result row.
        exon_id: copied into every result row.
        pam: PAM motif; everything after its first character is matched
            literally.
        target_length: number of bases taken upstream of the PAM.

    Returns:
        List of rows ``[target_seq, gRNA, chr, start, end, strand,
        transcript_id, exon_id]`` where ``target_seq`` is protospacer + PAM,
        ``gRNA`` is the protospacer with T replaced by U, and start/end/strand
        are strings.

    Raises:
        KeyError: if a protospacer contains a character other than A/C/G/T.
    """
    dna_to_rna = {'A': 'A', 'T': 'U', 'C': 'C', 'G': 'G'}
    pam_length = len(pam)
    pam_suffix = pam[1:]
    targets = []

    # i indexes the first PAM base. Starting at target_length guarantees a
    # full protospacer fits upstream of it.
    first_pam_index = max(target_length, 0)
    for i in range(first_pam_index, len(sequence) - pam_length + 1):
        if sequence[i + 1:i + pam_length] != pam_suffix:
            continue

        proto_start = i - target_length
        pam_end = i + pam_length  # exclusive index just past the PAM
        target_seq = sequence[proto_start:pam_end]

        if strand == -1:
            # On the minus strand, offsets count down from `end`.
            tar_start = end - (pam_end - 1)
            tar_end = end - proto_start
        else:
            tar_start = start + proto_start
            tar_end = start + pam_end - 1

        gRNA = ''.join(dna_to_rna[base] for base in sequence[proto_start:i])
        targets.append([target_seq, gRNA, chr, str(tar_start), str(tar_end),
                        str(strand), transcript_id, exon_id])
    return targets
# Function to predict on-target efficiency and format output
def format_prediction_output(targets, model_path):
    """Score each target with the model and return flat output rows.

    Args:
        targets: rows shaped like those produced by ``find_crispr_targets``:
            ``[target_seq, gRNA, chr, start, end, strand, transcript_id,
            exon_id]``.
        model_path: path of the weights file loaded into the model.

    Returns:
        List of rows ``[chr, start, end, strand, transcript_id, exon_id,
        target_seq, gRNA, prediction]``; predictions above 100 are capped at
        100. An empty ``targets`` returns ``[]`` without building the model.
    """
    targets = list(targets)
    if not targets:
        return []

    model = MultiHeadAttention_model(input_shape=(23, 4))
    model.load_weights(model_path)

    # Encode every target and score them in a single predict() call instead of
    # invoking predict() once per target inside the loop.
    encoded_batch = np.concatenate([get_seqcode(target[0]) for target in targets], axis=0)
    scores = model.predict(encoded_batch, verbose=0)

    formatted_data = []
    for target, score in zip(targets, scores):
        prediction = float(score[0])
        if prediction > 100:
            prediction = 100

        target_seq = target[0]
        gRNA = target[1]
        chrom = target[2]
        start = target[3]
        end = target[4]
        strand = target[5]
        transcript_id = target[6]
        exon_id = target[7]
        formatted_data.append([chrom, start, end, strand, transcript_id, exon_id,
                               target_seq, gRNA, prediction])
    return formatted_data
def process_gene(gene_symbol, model_path):
    """Collect scored CRISPR targets for every exon of every transcript of a gene.

    Args:
        gene_symbol: gene symbol passed to ``fetch_ensembl_transcripts``.
        model_path: weights path passed to ``format_prediction_output``.

    Returns:
        Tuple ``(results, all_gene_sequences, all_exons)``: the prediction
        rows in discovery order (not sorted), the exon sequences retrieved
        (one entry per transcript/exon pair), and every exon record seen.
        All three are empty when no transcripts could be retrieved.
    """
    results = []
    all_exons = []  # every exon record of every transcript
    all_gene_sequences = []  # one sequence per successfully fetched exon

    transcripts = fetch_ensembl_transcripts(gene_symbol)
    if not transcripts:
        print("Failed to retrieve transcripts.")
        return results, all_gene_sequences, all_exons

    # The same exon can occur in several transcripts; remember sequences
    # already downloaded so each exon id is fetched at most once. Failed
    # fetches are not cached, so they are retried on the next occurrence.
    sequence_by_exon = {}

    for transcript in transcripts:
        exons = transcript['Exon']
        all_exons.extend(exons)
        transcript_id = transcript['id']

        for exon in exons:
            exon_id = exon['id']
            gene_sequence = sequence_by_exon.get(exon_id)
            if gene_sequence is None:
                gene_sequence = fetch_ensembl_sequence(exon_id)
                if gene_sequence:
                    sequence_by_exon[exon_id] = gene_sequence

            if not gene_sequence:
                print(f"Failed to retrieve gene sequence for exon {exon_id}.")
                continue

            all_gene_sequences.append(gene_sequence)
            start = exon['start']
            end = exon['end']
            strand = exon['strand']
            chrom = exon['seq_region_name']

            # Find potential CRISPR targets within the exon
            targets = find_crispr_targets(gene_sequence, chrom, start, end, strand,
                                          transcript_id, exon_id)
            if targets:
                # Score the targets and append the formatted rows
                results.extend(format_prediction_output(targets, model_path))

    return results, all_gene_sequences, all_exons
def create_genbank_features(data):
    """Convert prediction rows into Biopython ``SeqFeature`` objects.

    Args:
        data: a list of rows or a pandas DataFrame. Each row is indexed
            positionally: ``[1]`` start, ``[2]`` end, ``[3]`` strand,
            ``[7]`` gRNA and ``[8]`` prediction.

    Returns:
        List of ``misc_feature`` features, one per row whose start and end
        could be converted to ``int``; other rows are reported and skipped.

    Raises:
        TypeError: if ``data`` is neither a list nor a DataFrame.
    """
    # If the input data is a DataFrame, convert it to a list of lists
    if isinstance(data, pd.DataFrame):
        formatted_data = data.values.tolist()
    elif isinstance(data, list):
        formatted_data = data
    else:
        raise TypeError("Data should be either a list or a pandas DataFrame.")

    # Strand values seen in this file are str(strand) of a numeric strand
    # ('1' / '-1'); '+' is kept for backward compatibility, and stringifying
    # first also covers integer values.
    plus_strand_values = {'+', '1', '+1', '1.0'}

    features = []
    for row in formatted_data:
        try:
            start = int(row[1])
            end = int(row[2])
        except (TypeError, ValueError) as e:
            # TypeError covers missing (None) coordinates.
            print(f"Error converting start/end to int: {row[1]}, {row[2]} - {e}")
            continue

        strand = 1 if str(row[3]).strip() in plus_strand_values else -1
        location = FeatureLocation(start=start, end=end, strand=strand)
        feature = SeqFeature(location=location, type="misc_feature", qualifiers={
            'label': row[7],  # Use gRNA as the label
            'note': f"Prediction: {row[8]}",  # Include the prediction score
        })
        features.append(feature)
    return features
def generate_genbank_file_from_df(df, gene_sequence, gene_symbol, output_path):
    """Write a GenBank record holding the sequence and one feature per prediction row.

    Args:
        df: rows accepted by ``create_genbank_features`` (list or DataFrame).
        gene_sequence: the record sequence. A string is used as-is, a list or
            tuple of sequences is concatenated in order, and any other object
            is converted with ``str()``.
        gene_symbol: used as the record id and name, and in the description.
        output_path: destination handed to ``SeqIO.write``.
    """
    if isinstance(gene_sequence, (list, tuple)):
        # str() on a list would embed brackets, quotes and commas in the
        # sequence; join the parts so the record holds only sequence text.
        gene_sequence = ''.join(str(part) for part in gene_sequence)
    elif not isinstance(gene_sequence, str):
        gene_sequence = str(gene_sequence)

    features = create_genbank_features(df)
    record = SeqRecord(Seq(gene_sequence), id=gene_symbol, name=gene_symbol,
                       description=f'CRISPR Cas9 predicted targets for {gene_symbol}',
                       features=features)
    record.annotations["molecule_type"] = "DNA"
    SeqIO.write(record, output_path, "genbank")
def create_bed_file_from_df(df, output_path):
    """Write one six-column, tab-separated BED line per DataFrame row.

    Columns written: ``Chr``, ``Start Pos``, ``End Pos``, ``gRNA``,
    ``Prediction`` and a ``+``/``-`` strand derived from ``Strand``.

    Args:
        df: DataFrame providing the columns named above.
        output_path: path of the file to create or overwrite.
    """
    # Accept string and numeric spellings of the plus strand; previously only
    # the exact string '1' was recognised, so an integer 1 was written as '-'.
    plus_strand_values = {'1', '+', '+1', '1.0'}

    # NOTE(review): BED starts are 0-based; if "Start Pos" holds 1-based
    # coordinates the start column is off by one — confirm against the
    # producer of this DataFrame before changing.
    with open(output_path, 'w') as bed_file:
        for _, row in df.iterrows():
            chrom = row["Chr"]
            start = int(row["Start Pos"])
            end = int(row["End Pos"])
            strand = '+' if str(row["Strand"]).strip() in plus_strand_values else '-'
            gRNA = row["gRNA"]
            score = str(row["Prediction"])
            bed_file.write(f"{chrom}\t{start}\t{end}\t{gRNA}\t{score}\t{strand}\n")
def create_csv_from_df(df, output_path):
    """Save ``df`` as a CSV file at ``output_path``, omitting the index column."""
    include_index = False
    df.to_csv(output_path, index=include_index)