Spaces:
No application file
No application file
| import pandas as pd | |
| import numpy as np | |
| import plotly.graph_objects as go | |
| from Bio import SeqIO, AlignIO | |
| from Bio.Phylo.TreeConstruction import DistanceCalculator, DistanceTreeConstructor, DistanceMatrix | |
| from Bio.Phylo.BaseTree import Tree | |
| from Bio.Align import MultipleSeqAlignment | |
| from Bio.Seq import Seq | |
| from Bio.SeqRecord import SeqRecord | |
| from sklearn.ensemble import RandomForestClassifier | |
| from sklearn.model_selection import train_test_split | |
| from sklearn.preprocessing import LabelEncoder | |
| import warnings | |
| import os | |
| import sys | |
| import re | |
| import time | |
| from pathlib import Path | |
| from typing import Dict, List, Tuple, Optional | |
| import itertools | |
| import argparse | |
| warnings.filterwarnings('ignore') | |
| class PhylogeneticTreeAnalyzer: | |
| """Analyzes phylogenetic relationships using ML-based sequence similarity and tree construction.""" | |
def __init__(self):
    """Create an analyzer with no data loaded and no models trained."""
    # --- Dataset and the query currently under analysis ---
    self.data = None
    self.query_id = None
    self.query_sequence = None

    # --- Similarity search settings and results ---
    self.matching_percentage = 95.0
    self.actual_percentage = None
    self.matched_sequences = []
    self.similarity_scores = {}

    # --- Classifiers, their label encoders and test-split accuracies ---
    self.ai_model = None  # ML model for sequence classification
    self.label_encoder = LabelEncoder()  # Encoder for ML labels
    self.ml_model_accuracy = None  # Accuracy of ML model
    self.genotype_model = None  # Model for genotype prediction
    self.genotype_label_encoder = LabelEncoder()  # Encoder for genotype labels
    self.genotype_model_accuracy = None  # Accuracy of genotype model

    # --- ML analysis artefacts ---
    self.ml_tree = None
    self.ml_alignment = None
    self.ml_results = {}

    # --- Tree structure and layout bookkeeping ---
    self.tree_structure = {}
    self.horizontal_line_tracker = []
    self.query_ml_group = None
    self.base_horizontal_length = 1.2
| # --- Data Loading --- | |
def load_data(self, data_file: str) -> bool:
    """Load sequence data from a CSV file into ``self.data``.

    The columns ``Accession Number``, ``F-gene``, ``ML`` and ``Genotype`` are
    indexed by the other methods of this class, so their presence is checked
    up front. ``self.data`` is only replaced when the file is readable and
    has all four columns; on failure the previous value is left untouched.

    Args:
        data_file: Path (or anything ``pandas.read_csv`` accepts) of the CSV.

    Returns:
        True when the data was loaded and validated, False otherwise.
    """
    required_columns = ('Accession Number', 'F-gene', 'ML', 'Genotype')
    try:
        frame = pd.read_csv(data_file)
    except Exception as e:
        print(f"Error loading data: {e}")
        return False

    missing = [column for column in required_columns if column not in frame.columns]
    if missing:
        # Report the actual problem instead of a bare KeyError message.
        print(f"Error loading data: missing required column(s): {', '.join(missing)}")
        return False

    self.data = frame
    print(f"✓ Data loaded: {len(frame)} sequences, "
          f"{frame['ML'].nunique()} ML groups, "
          f"{frame['Genotype'].nunique()} genotypes")
    return True
| # --- Model Training --- | |
def train_ai_model(self) -> bool:
    """Train RandomForest models for ML group and genotype prediction.

    Every ``F-gene`` sequence is reduced to a 100-value feature vector: the
    counts of the first 50 3-mers followed by the counts of the first 50
    4-mers, both taken in ``itertools.product('ATGC', ...)`` order. Sequences
    with fewer than three valid bases get an all-zero vector.

    NOTE(review): only 50 of the 64 possible 3-mers and 50 of the 256
    possible 4-mers are used. This is kept as-is because the prediction
    methods build their vectors the same way.

    Returns:
        True when the ML-group model was trained (the genotype model is
        trained in addition when at least two genotype classes exist);
        False when there is too little data, fewer than two ML classes, or
        any error occurs.
    """
    # Local import keeps this method self-contained.
    from collections import Counter

    try:
        if len(self.data) < 10:
            print("⚠️ Insufficient data for training (minimum 10 samples)")
            return False
        print("🤖 Training AI models...")

        # The k-mer vocabulary is the same for every sequence: build it once.
        vocab_3 = [''.join(p) for p in itertools.product('ATGC', repeat=3)][:50]
        vocab_4 = [''.join(p) for p in itertools.product('ATGC', repeat=4)][:50]

        features = []
        for seq in self.data['F-gene'].fillna('').astype(str):
            seq_clean = re.sub(r'[^ATGC]', '', seq.upper())
            if len(seq_clean) < 3:
                features.append([0] * 100)
                continue
            # Counter makes one pass per k instead of list.count() per k-mer.
            counts_3 = Counter(seq_clean[i:i + 3] for i in range(len(seq_clean) - 2))
            counts_4 = Counter(seq_clean[i:i + 4] for i in range(len(seq_clean) - 3))
            features.append([counts_3[kmer] for kmer in vocab_3]
                            + [counts_4[kmer] for kmer in vocab_4])
        X = np.array(features)

        # Train ML model
        ml_targets = self.label_encoder.fit_transform(self.data['ML'].fillna('Unknown'))
        if len(np.unique(ml_targets)) < 2:
            print("⚠️ Need at least 2 ML classes for training")
            return False
        X_train, X_test, y_train, y_test = train_test_split(X, ml_targets, test_size=0.2, random_state=42)
        self.ai_model = RandomForestClassifier(n_estimators=100, random_state=42)
        self.ai_model.fit(X_train, y_train)
        self.ml_model_accuracy = self.ai_model.score(X_test, y_test)
        print(f"✓ ML model trained with accuracy: {self.ml_model_accuracy:.2%}")

        # Train genotype model
        genotype_targets = self.genotype_label_encoder.fit_transform(self.data['Genotype'].fillna('Unknown'))
        if len(np.unique(genotype_targets)) >= 2:
            X_train, X_test, y_train, y_test = train_test_split(X, genotype_targets, test_size=0.2, random_state=42)
            self.genotype_model = RandomForestClassifier(n_estimators=100, random_state=42)
            self.genotype_model.fit(X_train, y_train)
            self.genotype_model_accuracy = self.genotype_model.score(X_test, y_test)
            print(f"✓ Genotype model trained with accuracy: {self.genotype_model_accuracy:.2%}")
        else:
            # The encoder above was just re-fitted; a model from an earlier
            # training run would no longer match it, so drop it.
            self.genotype_model = None
            self.genotype_model_accuracy = None
        return True
    except Exception as e:
        print(f"Error training models: {e}")
        return False
def predict_ml_group(self, sequence: str) -> str:
    """Predict the ML group of a sequence with the trained ML-group model.

    The feature vector is built as in training: counts of the first 50
    3-mers followed by counts of the first 50 4-mers, both in
    ``itertools.product('ATGC', ...)`` order.

    Args:
        sequence: Nucleotide sequence; characters other than A/T/G/C are
            ignored.

    Returns:
        The decoded ML group label, or ``"Unknown"`` when no model is
        available, fewer than three valid bases remain, or prediction fails.
    """
    # Local import keeps this method self-contained.
    from collections import Counter

    try:
        # Compare with None explicitly: truthiness of an ensemble estimator
        # goes through __len__ rather than "is a model set".
        if self.ai_model is None:
            return "Unknown"
        seq_clean = re.sub(r'[^ATGC]', '', str(sequence).upper())
        if len(seq_clean) < 3:
            return "Unknown"

        vocab_3 = [''.join(p) for p in itertools.product('ATGC', repeat=3)][:50]
        vocab_4 = [''.join(p) for p in itertools.product('ATGC', repeat=4)][:50]
        # Counter makes one pass per k instead of list.count() per k-mer.
        counts_3 = Counter(seq_clean[i:i + 3] for i in range(len(seq_clean) - 2))
        counts_4 = Counter(seq_clean[i:i + 4] for i in range(len(seq_clean) - 3))
        feature_vector = ([counts_3[kmer] for kmer in vocab_3]
                          + [counts_4[kmer] for kmer in vocab_4])

        X = np.array([feature_vector])
        return self.label_encoder.inverse_transform(self.ai_model.predict(X))[0]
    except Exception as e:
        print(f"Error predicting ML group: {e}")
        return "Unknown"
def predict_genotype(self, sequence: str) -> str:
    """Predict the genotype of a sequence with the trained genotype model.

    The feature vector is built as in training: counts of the first 50
    3-mers followed by counts of the first 50 4-mers, both in
    ``itertools.product('ATGC', ...)`` order.

    Args:
        sequence: Nucleotide sequence; characters other than A/T/G/C are
            ignored.

    Returns:
        The decoded genotype label, or ``"Unknown"`` when no model is
        available, fewer than three valid bases remain, or prediction fails.
    """
    # Local import keeps this method self-contained.
    from collections import Counter

    try:
        # Compare with None explicitly: truthiness of an ensemble estimator
        # goes through __len__ rather than "is a model set".
        if self.genotype_model is None:
            return "Unknown"
        seq_clean = re.sub(r'[^ATGC]', '', str(sequence).upper())
        if len(seq_clean) < 3:
            return "Unknown"

        vocab_3 = [''.join(p) for p in itertools.product('ATGC', repeat=3)][:50]
        vocab_4 = [''.join(p) for p in itertools.product('ATGC', repeat=4)][:50]
        # Counter makes one pass per k instead of list.count() per k-mer.
        counts_3 = Counter(seq_clean[i:i + 3] for i in range(len(seq_clean) - 2))
        counts_4 = Counter(seq_clean[i:i + 4] for i in range(len(seq_clean) - 3))
        feature_vector = ([counts_3[kmer] for kmer in vocab_3]
                          + [counts_4[kmer] for kmer in vocab_4])

        X = np.array([feature_vector])
        return self.genotype_label_encoder.inverse_transform(self.genotype_model.predict(X))[0]
    except Exception as e:
        print(f"Error predicting genotype: {e}")
        return "Unknown"
| # --- Sequence Processing --- | |
def find_query_sequence(self, query_input: str) -> bool:
    """Resolve the query by accession number, by F-gene match, or as a novel sequence.

    The input is tried, in order, as:
      1. an exact ``Accession Number`` present in ``self.data``;
      2. a sequence (non-ATGC characters removed) equal to a stored ``F-gene``;
      3. a novel sequence of at least 10 valid bases, which receives a
         synthetic ``QUERY_#####`` identifier.

    The synthetic identifier is derived from a SHA-256 digest of the cleaned
    sequence. The built-in ``hash()`` is salted per interpreter process for
    strings, so it would give the same sequence a different ID on every run.

    Args:
        query_input: Accession number or nucleotide sequence.

    Returns:
        True when ``self.query_id`` / ``self.query_sequence`` were set,
        False when the input is too short, not found, or an error occurred.
    """
    # Local import keeps this method self-contained.
    import hashlib

    try:
        query_input = str(query_input).strip()
        if query_input in self.data['Accession Number'].values:
            self.query_id = query_input
            query_row = self.data[self.data['Accession Number'] == query_input].iloc[0]
            self.query_sequence = query_row['F-gene']
            print(f"✓ Query found by accession: {query_input}, ML: {query_row['ML']}, Genotype: {query_row['Genotype']}")
            return True

        query_clean = re.sub(r'[^ATGC]', '', query_input.upper())
        if query_clean in self.data['F-gene'].values:
            query_row = self.data[self.data['F-gene'] == query_clean].iloc[0]
            self.query_id = query_row['Accession Number']
            self.query_sequence = query_clean
            print(f"✓ Query matched to accession: {self.query_id}, ML: {query_row['ML']}, Genotype: {query_row['Genotype']}")
            return True

        if len(query_clean) >= 10:
            # Stable across runs; the "QUERY_" prefix and 5-digit width are kept.
            digest = hashlib.sha256(query_clean.encode('ascii')).hexdigest()
            self.query_id = f"QUERY_{int(digest, 16) % 100000:05d}"
            self.query_sequence = query_clean
            predicted_ml = self.predict_ml_group(query_clean)
            predicted_genotype = self.predict_genotype(query_clean)
            print(f"✓ Novel query accepted: {self.query_id}, Length: {len(query_clean)}, "
                  f"Predicted ML: {predicted_ml}, Predicted Genotype: {predicted_genotype}")
            return True

        print("✗ Invalid query: Too short (<10) or not found")
        return False
    except Exception as e:
        print(f"Error processing query: {e}")
        return False
def calculate_f_gene_similarity(self, seq1: str, seq2: str) -> float:
    """Return the percentage similarity of two sequences as a 5-mer Jaccard index.

    Both inputs are upper-cased and stripped of every character other than
    A/T/G/C. The score is ``|shared 5-mers| / |all 5-mers| * 100``, rounded
    to two decimals.

    When both cleaned sequences are shorter than five bases there are no
    5-mers to compare, so they are compared for exact equality instead of
    being reported as 100% similar regardless of content.

    Args:
        seq1: First sequence.
        seq2: Second sequence.

    Returns:
        Similarity in the range 0.0-100.0; 0.0 for empty input or on error.
    """
    try:
        if not seq1 or not seq2:
            return 0.0
        seq1 = re.sub(r'[^ATGC]', '', str(seq1).upper())
        seq2 = re.sub(r'[^ATGC]', '', str(seq2).upper())
        if not seq1 or not seq2:
            return 0.0

        k = 5
        kmers1 = {seq1[i:i + k] for i in range(len(seq1) - k + 1)}
        kmers2 = {seq2[i:i + k] for i in range(len(seq2) - k + 1)}
        if not kmers1 and not kmers2:
            # Both too short for k-mer analysis: fall back to exact comparison.
            return 100.0 if seq1 == seq2 else 0.0
        if not kmers1 or not kmers2:
            return 0.0

        shared = len(kmers1 & kmers2)
        total = len(kmers1 | kmers2)
        return round((shared / total) * 100, 2)
    except Exception as e:
        print(f"Error calculating similarity: {e}")
        return 0.0
def find_similar_sequences(self, target_percentage: float) -> Tuple[List[str], float]:
    """Select the dataset sequences whose similarity to the query is near a target.

    Every row except the query itself is scored with
    ``calculate_f_gene_similarity``. Rows within 2 points of
    ``target_percentage`` are kept; when there are none, the closest score
    becomes the effective percentage and rows within 1 point of it are kept.
    At most 50 rows (highest similarity first) are returned, and their
    scores are stored in ``self.similarity_scores``.

    Returns:
        ``(matched_ids, actual_percentage)``; ``([], target_percentage)``
        when nothing can be compared or an error occurs.
    """
    try:
        print(f"🔍 Finding sequences with {target_percentage}% similarity...")

        scored = []
        for _, row in self.data.iterrows():
            accession = row['Accession Number']
            if accession == self.query_id:
                continue
            score = self.calculate_f_gene_similarity(self.query_sequence, row['F-gene'])
            scored.append({
                'id': accession,
                'similarity': score,
                'ml': row.get('ML', 'Unknown'),
                'genotype': row.get('Genotype', 'Unknown'),
            })

        if not scored:
            print("❌ No valid sequences for comparison")
            return [], target_percentage

        # Highest similarity first, so truncation below keeps the best rows.
        scored.sort(key=lambda entry: entry['similarity'], reverse=True)

        def within(centre, tolerance):
            return [entry for entry in scored if abs(entry['similarity'] - centre) <= tolerance]

        picked = within(target_percentage, 2.0)
        if picked:
            actual_percentage = target_percentage
        else:
            nearest = min(scored, key=lambda entry: abs(entry['similarity'] - target_percentage))
            actual_percentage = nearest['similarity']
            picked = within(actual_percentage, 1.0)
            print(f"⚠ No sequences at {target_percentage}%. Using closest: {actual_percentage:.1f}%")

        limit = 50
        if len(picked) > limit:
            picked = picked[:limit]
            print(f"⚠ Limited to top {limit} matches")

        self.similarity_scores = {entry['id']: entry['similarity'] for entry in picked}
        matched_ids = [entry['id'] for entry in picked]

        all_scores = [entry['similarity'] for entry in scored]
        max_sim = max(all_scores)
        min_sim = min(all_scores)
        avg_sim = sum(all_scores) / len(all_scores)
        print(f"✓ Found {len(matched_ids)} sequences at ~{actual_percentage:.1f}% similarity, "
              f"Range: {min_sim:.1f}% - {max_sim:.1f}% (avg: {avg_sim:.1f}%)")
        return matched_ids, actual_percentage
    except Exception as e:
        print(f"Error finding similar sequences: {e}")
        return [], target_percentage
| # --- Tree Construction --- | |
def build_tree_structure(self, matched_ids: List[str]) -> Dict:
    """Build the hierarchical tree (root -> ML group -> genotype -> sequences).

    Every row of ``self.data`` is grouped by its ``ML`` and ``Genotype``
    values and flagged as query and/or matched. A novel query (ID starting
    with ``QUERY_``) is not in the dataset, so it is added under its
    predicted ML group and genotype. The groups are then handed to
    ``_normalize_ml_groups`` and ``_build_normalized_ml_nodes``.

    Missing ``ML`` / ``Genotype`` values are grouped under ``'Unknown'``,
    matching the ``fillna('Unknown')`` used when the models are trained;
    a NaN group name would otherwise make group-name normalisation fail.

    Args:
        matched_ids: Accession numbers selected by the similarity search.

    Returns:
        The tree dict (also stored in ``self.tree_structure``), or ``{}``
        when an error occurs.
    """
    try:
        print("🌳 Building normalized tree structure...")
        tree_structure = {
            'root': {'name': 'Root', 'type': 'root', 'children': {}, 'x': 0, 'y': 0,
                     'has_vertical_attachment': False, 'extension_level': 0}
        }

        # Set membership is O(1) per row; the list lookup was O(len(matched_ids)).
        matched_lookup = set(matched_ids)

        ml_groups = {}
        for _, row in self.data.iterrows():
            ml_group = 'Unknown' if pd.isna(row['ML']) else row['ML']
            genotype = 'Unknown' if pd.isna(row['Genotype']) else row['Genotype']
            seq_id = row['Accession Number']
            ml_groups.setdefault(ml_group, {}).setdefault(genotype, []).append({
                'id': seq_id, 'data': row.to_dict(), 'is_query': seq_id == self.query_id,
                'is_matched': seq_id in matched_lookup,
                'similarity': self.similarity_scores.get(seq_id, 0.0)
            })

        if self.query_id.startswith("QUERY_"):
            # Novel query: place it where the classifiers predict it belongs.
            predicted_ml = self.predict_ml_group(self.query_sequence)
            predicted_genotype = self.predict_genotype(self.query_sequence)
            ml_groups.setdefault(predicted_ml, {}).setdefault(predicted_genotype, []).append({
                'id': self.query_id, 'data': {
                    'F-gene': self.query_sequence, 'ML': predicted_ml, 'Genotype': predicted_genotype,
                    'Accession Number': self.query_id
                }, 'is_query': True, 'is_matched': False, 'similarity': 100.0
            })

        normalized_ml_groups = self._normalize_ml_groups(ml_groups)
        self._build_normalized_ml_nodes(tree_structure, normalized_ml_groups, matched_ids)
        self.tree_structure = tree_structure
        print("✓ Tree structure built")
        return tree_structure
    except Exception as e:
        print(f"Error building tree structure: {e}")
        return {}
def build_tree_structure_with_ml_safe(self, matched_ids: List[str]) -> Dict:
    """Build the tree and attach a summary of the ML analysis under ``'ml_analysis'``.

    ``perform_ml_analysis_safe`` runs first, then ``build_tree_structure``.
    When the analysis returns a result containing ``'tree'``, its statistics
    are copied into the tree and the ML tree/alignment are stored on the
    instance; otherwise the tree is marked as having no ML tree. If anything
    raises, the plain tree is built as a fallback.

    Returns:
        The tree dict, or ``{'error': 'Tree construction failed'}`` when the
        fallback fails as well.
    """
    try:
        print("🌳 Building ML-enhanced tree structure...")
        ml_results = self.perform_ml_analysis_safe(matched_ids)
        tree_structure = self.build_tree_structure(matched_ids)

        analysis_ok = bool(ml_results) and 'tree' in ml_results
        if not analysis_ok:
            tree_structure['ml_analysis'] = {'ml_tree_available': False, 'error': 'ML analysis failed'}
            print("⚠ ML analysis failed, using standard tree")
            return tree_structure

        summary = {
            field: ml_results[field]
            for field in ('log_likelihood', 'sequence_count', 'alignment_length')
        }
        summary['ml_tree_available'] = True
        tree_structure['ml_analysis'] = summary
        self.ml_tree = ml_results['tree']
        self.ml_alignment = ml_results.get('alignment')
        print("✓ Tree enhanced with ML analysis")
        return tree_structure
    except Exception as e:
        print(f"Error building ML-enhanced tree: {e}")

    # Reached only after a failure above: fall back to the plain tree.
    try:
        return self.build_tree_structure(matched_ids)
    except Exception as e2:
        print(f"Fallback failed: {e2}")
        return {'error': 'Tree construction failed'}
| def _normalize_ml_groups(self, ml_groups: Dict) -> Dict: | |
| """Normalizes ML group names for hierarchical organization.""" | |
| try: | |
| normalized_groups = {} | |
| for ml_name, genotypes in ml_groups.items(): | |
| base_ml = 'UNCL' if ml_name.startswith('UNCL') else ml_name.split('.')[0] if '.' in ml_name and any(c.isdigit() for c in ml_name) else ml_name | |
| if base_ml not in normalized_groups: | |
| normalized_groups[base_ml] = {'full_ml_groups': {}, 'representative_sequences': [], 'has_special_sequences': False} | |
| has_special = any(any(seq['is_query'] or seq['is_matched'] for seq in seqs) for seqs in genotypes.values()) | |
| if has_special: | |
| normalized_groups[base_ml]['has_special_sequences'] = True | |
| normalized_groups[base_ml]['full_ml_groups'][ml_name] = genotypes | |
| elif len(normalized_groups[base_ml]['representative_sequences']) < 2: | |
| for genotype, sequences in list(genotypes.items())[:2]: | |
| if len(normalized_groups[base_ml]['representative_sequences']) < 2: | |
| normalized_groups[base_ml]['representative_sequences'].extend(sequences[:1]) | |
| return normalized_groups | |
| except Exception as e: | |
| print(f"Error normalizing ML groups: {e}") | |
| return {} | |
| def _build_normalized_ml_nodes(self, tree_structure: Dict, normalized_ml_groups: Dict, matched_ids: List[str]): | |
| """Builds normalized ML nodes with equal spacing.""" | |
| try: | |
| self.horizontal_line_tracker = [] | |
| self._identify_query_ml_group(normalized_ml_groups) | |
| ml_positions = self._calculate_dynamic_ml_positions(normalized_ml_groups) | |
| tree_structure['root']['has_vertical_attachment'] = len(normalized_ml_groups) > 1 | |
| for ml_idx, (base_ml, ml_data) in enumerate(normalized_ml_groups.items()): | |
| y_pos = ml_positions[ml_idx] | |
| has_vertical = ml_data['has_special_sequences'] and len(ml_data['full_ml_groups']) > 1 | |
| contains_query = base_ml == self.query_ml_group | |
| horizontal_length = self._determine_horizontal_line_length('normalized_ml_group', has_vertical, contains_query) | |
| x_pos = horizontal_length | |
| tree_structure['root']['children'][base_ml] = { | |
| 'name': base_ml, 'type': 'normalized_ml_group', 'children': {}, 'x': x_pos, 'y': y_pos, | |
| 'has_special_sequences': ml_data['has_special_sequences'], 'has_vertical_attachment': has_vertical, | |
| 'horizontal_line_length': horizontal_length, 'contains_query': contains_query | |
| } | |
| if ml_data['has_special_sequences']: | |
| self._build_full_ml_nodes(tree_structure['root']['children'][base_ml], ml_data['full_ml_groups'], | |
| y_pos, matched_ids, x_pos) | |
| else: | |
| self._add_representative_sequences(tree_structure['root']['children'][base_ml], | |
| ml_data['representative_sequences'], y_pos, x_pos) | |
| except Exception as e: | |
| print(f"Error building normalized ML nodes: {e}") | |
| def _build_full_ml_nodes(self, normalized_ml_node: Dict, full_ml_groups: Dict, base_y: float, matched_ids: List[str], parent_x: float): | |
| """Builds full ML nodes with genotypes.""" | |
| try: | |
| full_ml_positions = self._calculate_full_ml_positions(full_ml_groups, base_y) | |
| for ml_idx, (full_ml_name, genotypes) in enumerate(full_ml_groups.items()): | |
| y_pos = full_ml_positions[ml_idx] | |
| special_genotypes_count = sum(1 for g, seqs in genotypes.items() if any(s['is_query'] or s['is_matched'] for s in seqs)) | |
| has_vertical = special_genotypes_count > 1 | |
| contains_query = any(any(seq['is_query'] for seq in seqs) for seqs in genotypes.values()) | |
| horizontal_length = self._determine_horizontal_line_length('full_ml_group', has_vertical, contains_query) | |
| x_pos = parent_x + horizontal_length | |
| normalized_ml_node['children'][full_ml_name] = { | |
| 'name': full_ml_name, 'type': 'full_ml_group', 'children': {}, 'x': x_pos, 'y': y_pos, | |
| 'sequences_count': sum(len(seqs) for seqs in genotypes.values()), 'has_vertical_attachment': has_vertical, | |
| 'horizontal_line_length': horizontal_length, 'contains_query': contains_query | |
| } | |
| self._build_genotype_nodes(normalized_ml_node['children'][full_ml_name], genotypes, y_pos, matched_ids, x_pos) | |
| except Exception as e: | |
| print(f"Error building full ML nodes: {e}") | |
| def _build_genotype_nodes(self, full_ml_node: Dict, genotypes: Dict, base_y: float, matched_ids: List[str], parent_x: float): | |
| """Builds genotype nodes with sequences.""" | |
| try: | |
| special_genotypes = [(g, seqs) for g, seqs in genotypes.items() if any(s['is_query'] or s['is_matched'] for s in seqs)] | |
| if not special_genotypes: | |
| return | |
| genotype_positions = self._calculate_genotype_positions(special_genotypes, base_y) | |
| genotype_sequence_counts = [(g, seqs, len([s for s in seqs if s['is_query'] or s['is_matched']])) for g, seqs in special_genotypes] | |
| for gt_idx, (genotype, sequences, sequence_count) in enumerate(genotype_sequence_counts): | |
| y_pos = genotype_positions[gt_idx] | |
| special_sequences = [s for s in sequences if s['is_query'] or s['is_matched']] | |
| has_vertical = len(special_sequences) > 1 | |
| contains_query = any(s['is_query'] for s in sequences) | |
| horizontal_length = self._determine_genotype_horizontal_line_length(sequence_count, has_vertical, contains_query) | |
| x_pos = parent_x + horizontal_length | |
| full_ml_node['children'][genotype] = { | |
| 'name': genotype, 'type': 'genotype', 'children': {}, 'x': x_pos, 'y': y_pos, | |
| 'sequences': sequences, 'has_vertical_attachment': has_vertical, | |
| 'horizontal_line_length': horizontal_length, 'contains_query': contains_query, | |
| 'sequence_count': sequence_count | |
| } | |
| self._add_sequences_horizontal(full_ml_node['children'][genotype], sequences, y_pos, x_pos) | |
| except Exception as e: | |
| print(f"Error building genotype nodes: {e}") | |
| def _add_representative_sequences(self, normalized_ml_node: Dict, representative_sequences: List[Dict], base_y: float, parent_x: float): | |
| """Adds representative sequences to normalized ML nodes.""" | |
| try: | |
| if not representative_sequences: | |
| return | |
| has_vertical = len(representative_sequences) > 1 | |
| horizontal_length = self._determine_horizontal_line_length('representative', has_vertical) | |
| x_pos = parent_x + horizontal_length | |
| if len(representative_sequences) == 1: | |
| seq = representative_sequences[0] | |
| normalized_ml_node['children'][f"{seq['id']}_rep"] = { | |
| 'name': f"{seq['id']} (Rep)", 'type': 'representative_sequence', 'data': seq, | |
| 'x': x_pos, 'y': base_y, 'has_vertical_attachment': False, 'horizontal_line_length': horizontal_length | |
| } | |
| else: | |
| positions = self._calculate_sequence_positions(representative_sequences, base_y) | |
| for idx, seq in enumerate(representative_sequences): | |
| normalized_ml_node['children'][f"{seq['id']}_rep"] = { | |
| 'name': f"{seq['id']} (Rep)", 'type': 'representative_sequence', 'data': seq, | |
| 'x': x_pos, 'y': positions[idx], 'has_vertical_attachment': False, 'horizontal_line_length': horizontal_length | |
| } | |
| except Exception as e: | |
| print(f"Error adding representative sequences: {e}") | |
| def _add_sequences_horizontal(self, genotype_node: Dict, sequences: List[Dict], base_y: float, parent_x: float): | |
| """Adds sequences with similarity-based line lengths.""" | |
| try: | |
| query_line_length = 3.0 | |
| query_sequences = [s for s in sequences if s['is_query']] | |
| matched_sequences = [s for s in sequences if s['is_matched'] and not s['is_query']] | |
| all_special_sequences = query_sequences + matched_sequences | |
| if len(all_special_sequences) == 1: | |
| sequence = all_special_sequences[0] | |
| line_length = self._calculate_similarity_based_line_length(sequence, query_line_length) | |
| x_pos = parent_x + line_length | |
| genotype_node['children'][sequence['id']] = { | |
| 'name': f"{sequence['id']} ({sequence['similarity']}%)" if sequence['is_matched'] else sequence['id'], | |
| 'type': 'sequence', 'data': sequence, 'x': x_pos, 'y': base_y, | |
| 'has_vertical_attachment': False, 'similarity_line_length': line_length | |
| } | |
| else: | |
| sequence_positions = self._calculate_sequence_positions(all_special_sequences, base_y) | |
| for seq_idx, sequence in enumerate(all_special_sequences): | |
| line_length = self._calculate_similarity_based_line_length(sequence, query_line_length) | |
| x_pos = parent_x + line_length | |
| genotype_node['children'][sequence['id']] = { | |
| 'name': f"{sequence['id']} ({sequence['similarity']}%)" if sequence['is_matched'] else sequence['id'], | |
| 'type': 'sequence', 'data': sequence, 'x': x_pos, 'y': sequence_positions[seq_idx], | |
| 'has_vertical_attachment': False, 'similarity_line_length': line_length | |
| } | |
| except Exception as e: | |
| print(f"Error adding sequences: {e}") | |
| def _identify_query_ml_group(self, normalized_ml_groups: Dict): | |
| """Identifies the ML group containing the query sequence.""" | |
| try: | |
| for base_ml, ml_data in normalized_ml_groups.items(): | |
| if ml_data['has_special_sequences']: | |
| for genotypes in ml_data['full_ml_groups'].values(): | |
| for sequences in genotypes.values(): | |
| if any(seq['is_query'] for seq in sequences): | |
| self.query_ml_group = base_ml | |
| return | |
| except Exception as e: | |
| print(f"Error identifying query ML group: {e}") | |
| def _calculate_dynamic_ml_positions(self, normalized_ml_groups: Dict) -> List[float]: | |
| """Calculates equal Y positions for ML groups.""" | |
| try: | |
| ml_count = len(normalized_ml_groups) | |
| if ml_count == 0: | |
| return [] | |
| if ml_count == 1: | |
| return [0.0] | |
| total_spacing = (ml_count - 1) * 2.0 | |
| start_y = -total_spacing / 2 | |
| return [start_y + i * 2.0 for i in range(ml_count)] | |
| except Exception as e: | |
| print(f"Error calculating ML positions: {e}") | |
| return list(range(len(normalized_ml_groups))) | |
| def _calculate_full_ml_positions(self, full_ml_groups: Dict, base_y: float) -> List[float]: | |
| """Calculates equal positions for full ML groups.""" | |
| try: | |
| ml_count = len(full_ml_groups) | |
| if ml_count <= 1: | |
| return [base_y] | |
| spacing = 1.5 | |
| start_y = base_y - (spacing * (ml_count - 1)) / 2 | |
| return [start_y + i * spacing for i in range(ml_count)] | |
| except Exception as e: | |
| print(f"Error calculating full ML positions: {e}") | |
| return [base_y] * len(full_ml_groups) | |
| def _calculate_genotype_positions(self, special_genotypes: List, base_y: float) -> List[float]: | |
| """Calculates equal positions for genotypes.""" | |
| try: | |
| genotype_count = len(special_genotypes) | |
| if genotype_count <= 1: | |
| return [base_y] | |
| spacing = 1.0 | |
| start_y = base_y - (spacing * (genotype_count - 1)) / 2 | |
| return [start_y + i * spacing for i in range(genotype_count)] | |
| except Exception as e: | |
| print(f"Error calculating genotype positions: {e}") | |
| return [base_y] * len(special_genotypes) | |
| def _calculate_sequence_positions(self, sequences: List[Dict], base_y: float) -> List[float]: | |
| """Calculates equal positions for sequences.""" | |
| try: | |
| seq_count = len(sequences) | |
| if seq_count <= 1: | |
| return [base_y] | |
| spacing = 0.8 | |
| start_y = base_y - (spacing * (seq_count - 1)) / 2 | |
| return [start_y + i * spacing for i in range(seq_count)] | |
| except Exception as e: | |
| print(f"Error calculating sequence positions: {e}") | |
| return [base_y] * len(sequences) | |
| def _calculate_similarity_based_line_length(self, sequence: Dict, query_line_length: float) -> float: | |
| """Calculates line length based on sequence similarity.""" | |
| try: | |
| if sequence['is_query']: | |
| return query_line_length | |
| if sequence['is_matched']: | |
| similarity = sequence['similarity'] | |
| proportional_length = (similarity / 100.0) * query_line_length | |
| return max(proportional_length, query_line_length * 0.2) | |
| return query_line_length * 0.5 | |
| except Exception as e: | |
| print(f"Error calculating line length: {e}") | |
| return query_line_length * 0.5 | |
| def _determine_horizontal_line_length(self, node_type: str, has_vertical: bool, contains_query: bool = False) -> float: | |
| """Determines horizontal line length based on node type.""" | |
| try: | |
| base_length = self.base_horizontal_length | |
| if contains_query and node_type == 'normalized_ml_group': | |
| return base_length * 2.5 | |
| if has_vertical: | |
| current_max = base_length | |
| for length in self.horizontal_line_tracker: | |
| if length > current_max: | |
| current_max = length | |
| new_length = current_max + 0.3 | |
| self.horizontal_line_tracker.append(new_length) | |
| return new_length | |
| return base_length | |
| except Exception as e: | |
| print(f"Error determining line length: {e}") | |
| return self.base_horizontal_length | |
| def _determine_genotype_horizontal_line_length(self, sequence_count: int, has_vertical: bool, contains_query: bool = False) -> float: | |
| """Determines horizontal line length for genotype nodes.""" | |
| try: | |
| base_length = self.base_horizontal_length | |
| query_bonus = 0.5 if contains_query else 0.0 | |
| if sequence_count <= 1: | |
| length_multiplier = 1.0 | |
| elif sequence_count <= 3: | |
| length_multiplier = 1.6 | |
| elif sequence_count <= 5: | |
| length_multiplier = 2.3 | |
| else: | |
| length_multiplier = 6.0 | |
| return base_length * length_multiplier + query_bonus | |
| except Exception as e: | |
| print(f"Error determining genotype line length: {e}") | |
| return self.base_horizontal_length | |
| # --- Visualization --- | |
| def create_interactive_tree(self, matched_ids: List[str], actual_percentage: float) -> Optional[go.Figure]: | |
| """Creates an interactive horizontal phylogenetic tree visualization.""" | |
| try: | |
| print("🎨 Creating interactive tree visualization...") | |
| edge_x, edge_y = [], [] | |
| node_x, node_y = [], [] | |
| node_colors, node_text, node_hover, node_sizes = [], [], [], [] | |
| colors = { | |
| 'root': '#FF0000', 'normalized_ml_group': '#FFB6C1', 'full_ml_group': '#FF69B4', | |
| 'genotype': '#FFD700', 'representative_sequence': '#FFA500', 'query_sequence': '#4B0082', | |
| 'matched_sequence': '#6A5ACD', 'other_sequence': '#87CEEB' | |
| } | |
| def add_horizontal_edges(parent_x, parent_y, children_dict): | |
| if not children_dict: | |
| return | |
| children_list = list(children_dict.values()) | |
| if len(children_list) == 1: | |
| child = children_list[0] | |
| edge_x.extend([parent_x, child['x'], None]) | |
| edge_y.extend([parent_y, child['y'], None]) | |
| else: | |
| child_x_positions = [child['x'] for child in children_list] | |
| min_child_x = min(child_x_positions) | |
| intermediate_x = parent_x + (min_child_x - parent_x) * 0.8 | |
| edge_x.extend([parent_x, intermediate_x, None]) | |
| edge_y.extend([parent_y, parent_y, None]) | |
| child_y_positions = [child['y'] for child in children_list] | |
| min_y, max_y = min(child_y_positions), max(child_y_positions) | |
| edge_x.extend([intermediate_x, intermediate_x, None]) | |
| edge_y.extend([min_y, max_y, None]) | |
| for child in children_list: | |
| edge_x.extend([intermediate_x, child['x'], None]) | |
| edge_y.extend([child['y'], child['y'], None]) | |
| def get_node_color_and_size(node): | |
| if node['type'] == 'sequence': | |
| if node['data']['is_query']: | |
| return colors['query_sequence'], 10 | |
| if node['data']['is_matched']: | |
| return colors['matched_sequence'], 8 | |
| return colors['other_sequence'], 6 | |
| if node['type'] == 'representative_sequence': | |
| return colors['representative_sequence'], 7 | |
| if node['type'] == 'normalized_ml_group': | |
| return colors['normalized_ml_group'], 9 if node.get('has_special_sequences', False) else 7 | |
| if node['type'] == 'full_ml_group': | |
| return colors['full_ml_group'], 8 | |
| if node['type'] == 'genotype': | |
| return colors['genotype'], 7 | |
| return colors.get(node['type'], '#000000'), 7 | |
| def create_node_text(node): | |
| if node['type'] == 'sequence': | |
| return f"{node['name']}" if node['data']['is_matched'] and not node['data']['is_query'] else node['name'] | |
| if node['type'] == 'representative_sequence': | |
| return node['name'] | |
| if node['type'] == 'normalized_ml_group': | |
| return f"{node['name']} *" if node.get('has_special_sequences', False) else node['name'] | |
| return node['name'] | |
def create_hover_text(node):
    """Build the HTML tooltip shown when hovering over ``node``.

    The tooltip always starts with the bold node name followed by a
    ``Type:`` line; the remaining lines depend on ``node['type']``.
    Lines are separated with ``<br>``.
    """
    kind = node['type']
    parts = [f"<b>{node['name']}</b>"]

    if kind == 'sequence':
        flags = node['data']
        meta = flags['data']
        if flags['is_query']:
            role = 'Query'
        elif flags['is_matched']:
            role = 'Matched'
        else:
            role = 'Other'
        parts.append(f"Type: {role} Sequence")
        parts.append(f"ML Group: {meta.get('ML', 'N/A')}")
        parts.append(f"Genotype: {meta.get('Genotype', 'N/A')}")
        parts.append(f"Host: {meta.get('Host', 'N/A')}")
        parts.append(f"Country: {meta.get('Country', 'N/A')}")
        parts.append(f"Isolate: {meta.get('Isolate', 'N/A')}")
        parts.append(f"Year: {meta.get('Year', 'N/A')}")
        if flags['is_matched']:
            parts.append(f"<b>Similarity: {flags['similarity']}%</b>")
    elif kind == 'representative_sequence':
        meta = node['data']['data']
        parts.append("Type: Representative Sequence")
        parts.append(f"ML Group: {meta.get('ML', 'N/A')}")
        parts.append(f"Genotype: {meta.get('Genotype', 'N/A')}")
        parts.append(f"Host: {meta.get('Host', 'N/A')}")
        parts.append(f"Country: {meta.get('Country', 'N/A')}")
    elif kind == 'normalized_ml_group':
        parts.append("Type: Normalized ML Group")
        if node.get('has_special_sequences', False):
            parts.append("Contains query/matched sequences")
        else:
            parts.append("Representative sequences only")
    elif kind == 'full_ml_group':
        parts.append("Type: Full ML Group")
        if 'sequences_count' in node:
            parts.append(f"Total Sequences: {node['sequences_count']}")
    elif kind == 'genotype':
        parts.append("Type: Genotype")
        if 'sequences' in node:
            members = node['sequences']
            # "Special" means flagged as the query or as a match.
            special_count = sum(1 for seq in members if seq['is_query'] or seq['is_matched'])
            parts.append(f"Special Sequences: {special_count}/{len(members)}")
    else:
        # Unknown types are shown with a prettified type name.
        parts.append(f"Type: {kind.replace('_', ' ').title()}")

    return "<br>".join(parts)
def add_node_and_edges(node, parent_x=None, parent_y=None):
    """Record ``node`` and its whole subtree in the plot buffers.

    Nodes are visited in pre-order (a node before its children, children
    in mapping order). For each node the position, color, size, label and
    hover text are appended to the enclosing-scope lists, and the edges to
    its children are emitted through ``add_horizontal_edges``.

    ``parent_x`` and ``parent_y`` are accepted for call compatibility but
    are not read.
    """
    # Explicit stack instead of recursion; children are pushed reversed so
    # they are popped in their original order.
    pending = [node]
    while pending:
        current = pending.pop()
        x, y = current['x'], current['y']
        node_x.append(x)
        node_y.append(y)
        fill, marker_size = get_node_color_and_size(current)
        node_colors.append(fill)
        node_sizes.append(marker_size)
        node_text.append(create_node_text(current))
        node_hover.append(create_hover_text(current))

        offspring = current.get('children')
        if offspring:
            add_horizontal_edges(x, y, offspring)
            pending.extend(reversed(list(offspring.values())))
| root_node = self.tree_structure['root'] | |
| add_node_and_edges(root_node) | |
| if root_node['children']: | |
| add_horizontal_edges(root_node['x'], root_node['y'], root_node['children']) | |
| fig = go.Figure() | |
| fig.add_trace(go.Scatter( | |
| x=edge_x, y=edge_y, mode='lines', line=dict(width=1, color='gray'), | |
| hoverinfo='none', showlegend=False | |
| )) | |
| fig.add_trace(go.Scatter( | |
| x=node_x, y=node_y, mode='markers+text', | |
| marker=dict(size=node_sizes, color=node_colors, line=dict(width=1, color='black'), opacity=0.85), | |
| text=node_text, textposition="middle right", textfont=dict(size=9, color="black"), | |
| hoverinfo='text', hovertext=node_hover, showlegend=False | |
| )) | |
| min_x, max_x = min(node_x), max(node_x) if node_x else (0, 1) | |
| min_y, max_y = min(node_y), max(node_y) if node_y else (0, 1) | |
| x_range = max_x - min_x | |
| y_range = max_y - min_y | |
| x_padding = x_range * 0.2 if x_range > 0 else 1 | |
| y_padding = y_range * 0.2 if y_range > 0 else 1 | |
| width = min(1400, max(800, int(x_range * 80 + 400))) | |
| height = min(900, max(500, int(y_range * 40 + 300))) | |
| fig.update_layout( | |
| title=dict( | |
| text=f"Horizontal Phylogenetic Tree<br>Query: {self.query_id} | Similarity: {actual_percentage}% | Matched: {len(matched_ids)}", | |
| x=0.5, font=dict(size=12) | |
| ), | |
| xaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[min_x - x_padding, max_x + x_padding], automargin=True), | |
| yaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[min_y - y_padding, max_y + y_padding], automargin=True), | |
| plot_bgcolor="white", paper_bgcolor="white", hovermode="closest", | |
| width=width, height=height, margin=dict(l=20, r=100, t=40, b=10), | |
| showlegend=True, legend=dict(x=1.02, y=1, xanchor='left', yanchor='top', | |
| bgcolor='rgba(255,255,255,0.8)', bordercolor='gray', borderwidth=1, font=dict(size=10)) | |
| ) | |
| legend_elements = [ | |
| dict(name="Root", marker=dict(color=colors['root'], size=8)), | |
| dict(name="Normalized ML Groups", marker=dict(color=colors['normalized_ml_group'], size=8)), | |
| dict(name="Full ML Groups", marker=dict(color=colors['full_ml_group'], size=8)), | |
| dict(name="Genotypes", marker=dict(color=colors['genotype'], size=8)), | |
| dict(name="Query Sequence", marker=dict(color=colors['query_sequence'], size=10)), | |
| dict(name="Similar Sequences", marker=dict(color=colors['matched_sequence'], size=9)), | |
| dict(name="Representative Sequences", marker=dict(color=colors['representative_sequence'], size=8)), | |
| ] | |
| for element in legend_elements: | |
| fig.add_trace(go.Scatter(x=[None], y=[None], mode='markers', marker=element['marker'], name=element['name'], showlegend=True)) | |
| config = { | |
| 'displayModeBar': True, 'displaylogo': False, 'modeBarButtonsToRemove': ['select2d', 'lasso2d'], | |
| 'toImageButtonOptions': {'format': 'png', 'filename': 'phylogenetic_tree', 'height': height, 'width': width, 'scale': 2} | |
| } | |
| try: | |
| fig.show(config) | |
| except Exception as e: | |
| print(f"Warning: Could not display figure: {e}") | |
| return fig | |
| except Exception as e: | |
| print(f"Error creating tree visualization: {e}") | |
| return None | |
| # --- ML Analysis --- | |
def perform_ml_analysis_safe(self, matched_ids: List[str]) -> Dict:
    """Run the alignment -> distances -> tree -> likelihood pipeline.

    The query id is always placed first, followed by the matched ids that
    differ from it; at most 20 ids are analysed and at least 3 are
    required.

    Args:
        matched_ids: Accession numbers of the sequences matched to the query.

    Returns:
        A dict with the keys ``tree``, ``alignment``, ``distance_matrix``,
        ``log_likelihood``, ``sequence_count`` and ``alignment_length``, or
        an empty dict when any step fails. ``sequence_count`` is the number
        of ids submitted to the alignment step.
    """
    try:
        print("\n🧬 PERFORMING MAXIMUM LIKELIHOOD ANALYSIS")
        print("=" * 50)

        # Query first, then every match that is not the query itself.
        candidates = [self.query_id]
        candidates.extend(seq_id for seq_id in matched_ids if seq_id != self.query_id)

        # Cap the workload to keep memory use bounded.
        if len(candidates) > 20:
            print(f"Warning: Limiting analysis to 20 sequences (had {len(candidates)})")
            candidates = candidates[:20]
        if len(candidates) < 3:
            print("❌ Need at least 3 sequences for ML analysis")
            return {}

        # Each stage signals failure with an empty/falsy result.
        alignment = self.create_sequence_alignment(candidates)
        if not alignment:
            return {}
        distance_matrix = self.calculate_ml_distances(alignment)
        if distance_matrix.size == 0:
            return {}
        ml_tree = self.construct_ml_tree(alignment)
        if not ml_tree:
            return {}
        log_likelihood = self.calculate_ml_likelihood_safe(ml_tree, alignment)

        alignment_length = len(alignment[0]) if alignment else 0
        print("✅ ML analysis completed successfully")
        print(f" Sequences analyzed: {len(candidates)}")
        print(f" Alignment length: {alignment_length}")
        print(f" Log-likelihood: {log_likelihood:.2f}")
        return {
            'tree': ml_tree,
            'alignment': alignment,
            'distance_matrix': distance_matrix,
            'log_likelihood': log_likelihood,
            'sequence_count': len(candidates),
            'alignment_length': alignment_length,
        }
    except Exception as e:
        print(f"❌ ML analysis failed: {e}")
        import traceback
        traceback.print_exc()
        return {}
def create_sequence_alignment(self, sequence_ids: List[str]) -> Optional[MultipleSeqAlignment]:
    """Build a padded multiple sequence alignment for the given ids.

    Each id is looked up in ``self.data`` by ``Accession Number`` and its
    ``F-gene`` value is upper-cased and stripped of every character other
    than ``A T G C N -``. Sequences of 10 characters or fewer are dropped.
    The survivors are brought to equal length by ``_simple_alignment``
    (end padding only, no real alignment).

    If the query id has no row in the dataset (a novel query), the
    sequence stored in ``self.query_sequence`` is used instead, so the
    query is not silently lost from the analysis.

    Args:
        sequence_ids: Accession numbers (and possibly the query id).

    Returns:
        The alignment, or ``None`` when fewer than 2 usable sequences were
        found or an error occurred.
    """
    try:
        print("🧬 Creating multiple sequence alignment...")
        non_nucleotide = re.compile(r'[^ATGCN-]')
        records = []
        for seq_id in sequence_ids:
            try:
                rows = self.data[self.data['Accession Number'] == seq_id]
                if not rows.empty:
                    raw_sequence = str(rows.iloc[0]['F-gene'])
                elif seq_id == getattr(self, 'query_id', None) and getattr(self, 'query_sequence', None):
                    # Novel query: not in the dataset, but its sequence is known.
                    raw_sequence = str(self.query_sequence)
                else:
                    continue
                cleaned = non_nucleotide.sub('', raw_sequence.upper())
                if len(cleaned) > 10:  # minimum usable sequence length
                    records.append(SeqRecord(Seq(cleaned), id=seq_id, description=""))
            except Exception as e:
                # One bad record must not abort the whole alignment.
                print(f"Warning: Skipping sequence {seq_id}: {e}")
                continue

        if len(records) < 2:
            print("❌ Need at least 2 valid sequences for alignment")
            return None

        # Padding-based pseudo-alignment; a real aligner (MUSCLE, Clustal)
        # would give better results.
        aligned_sequences = self._simple_alignment(records)
        print(f"✓ Alignment created with {len(aligned_sequences)} sequences")
        return MultipleSeqAlignment(aligned_sequences)
    except Exception as e:
        print(f"Error creating alignment: {e}")
        return None
def _simple_alignment(self, sequences: List[SeqRecord]) -> List[SeqRecord]:
    """Bring all records to a common length by truncating and gap-padding.

    The target length is the longest input sequence, capped at 10000.
    Longer sequences are cut at that length; shorter ones are padded with
    ``-`` at the end. No residues are shifted, so this is not a true
    alignment.

    Returns:
        New records with equal-length sequences, or the unmodified input
        list if anything goes wrong.
    """
    try:
        target_length = max(len(record.seq) for record in sequences)
        # Cap the length to keep memory use bounded.
        if target_length > 10000:
            target_length = 10000
            print(f"Warning: Sequences truncated to {target_length} bp")

        return [
            SeqRecord(
                # ljust pads on the right only when the slice is too short.
                Seq(str(record.seq)[:target_length].ljust(target_length, '-')),
                id=record.id,
                description=record.description,
            )
            for record in sequences
        ]
    except Exception as e:
        print(f"Error in simple alignment: {e}")
        return sequences
def calculate_ml_distances(self, alignment: MultipleSeqAlignment) -> np.ndarray:
    """Compute the symmetric pairwise distance matrix of an alignment.

    Every pair of rows is scored with ``_calculate_ml_distance_pair``; a
    pair whose computation raises is assigned the fallback distance 1.0.

    Returns:
        An ``n x n`` array with a zero diagonal, or an empty array when the
        alignment has no rows or an unexpected error occurs.
    """
    try:
        print("📊 Calculating ML distances...")
        encoded = self._alignment_to_matrix(alignment)
        row_count = len(alignment)
        if row_count == 0:
            return np.array([])

        distances = np.zeros((row_count, row_count))
        # Upper triangle only; each value is mirrored into the lower one.
        for i, j in itertools.combinations(range(row_count), 2):
            try:
                pair_distance = self._calculate_ml_distance_pair(encoded[i], encoded[j])
                distances[i, j] = pair_distance
                distances[j, i] = pair_distance
            except Exception as e:
                print(f"Warning: Error calculating distance between sequences {i} and {j}: {e}")
                # Treat an unscorable pair as maximally distant.
                distances[i, j] = 1.0
                distances[j, i] = 1.0

        print("✓ ML distances calculated")
        return distances
    except Exception as e:
        print(f"Error calculating ML distances: {e}")
        return np.array([])
def _alignment_to_matrix(self, alignment: MultipleSeqAlignment) -> np.ndarray:
    """Encode an alignment as an array of integer nucleotide codes.

    Codes: A=0, T=1, G=2, C=3, N=4, gap=5. Any other character is encoded
    as 4, the same as N. One row is produced per record.

    Returns:
        The encoded array, or an empty array if the conversion fails.
    """
    try:
        codes = {'A': 0, 'T': 1, 'G': 2, 'C': 3, 'N': 4, '-': 5}
        encoded_rows = [
            [codes.get(base, 4) for base in str(record.seq).upper()]
            for record in alignment
        ]
        return np.array(encoded_rows)
    except Exception as e:
        print(f"Error converting alignment to matrix: {e}")
        return np.array([])
def _calculate_ml_distance_pair(self, seq1: np.ndarray, seq2: np.ndarray) -> float:
    """Return the Jukes-Cantor corrected distance between two encoded rows.

    Only positions where both rows hold a code below 4 (A, T, G or C) are
    compared. With ``p`` the proportion of differing positions, the result
    is ``-3/4 * ln(1 - 4p/3)`` clamped to ``[0, 1]``.

    Args:
        seq1: Integer-encoded sequence (see ``_alignment_to_matrix``).
        seq2: Integer-encoded sequence of the same length.

    Returns:
        A plain ``float`` in ``[0, 1]``. The maximum, 1.0, is returned when
        either row is empty, when no position is comparable, when
        ``p >= 0.75`` (the correction is undefined there), or when the
        computation raises.
    """
    try:
        if len(seq1) == 0 or len(seq2) == 0:
            return 1.0

        comparable = (seq1 < 4) & (seq2 < 4)  # excludes N's and gaps
        total_valid = int(np.count_nonzero(comparable))
        if total_valid == 0:
            return 1.0

        differences = int(np.count_nonzero(seq1[comparable] != seq2[comparable]))
        if differences == 0:
            # Short-circuit so identical rows yield 0.0 rather than -0.0.
            return 0.0

        p = differences / total_valid
        if p >= 0.75:
            return 1.0

        jc_distance = -0.75 * float(np.log(1.0 - 4.0 * p / 3.0))
        return min(max(jc_distance, 0.0), 1.0)
    except Exception:
        # Best effort: an unscorable pair counts as maximally distant.
        return 1.0
def construct_ml_tree(self, alignment: MultipleSeqAlignment) -> Optional[Tree]:
    """Build a tree for the alignment from its pairwise distances.

    Despite the name, the topology comes from neighbor joining
    (``_build_nj_tree_from_distances``) applied to the distances from
    ``calculate_ml_distances``; branch lengths are then post-processed by
    ``_optimize_branch_lengths_ml_safe``.

    Returns:
        The tree, or ``None`` when distances could not be computed, the
        tree could not be built, or an error occurred.
    """
    try:
        print("🌳 Constructing ML tree...")
        distance_matrix = self.calculate_ml_distances(alignment)
        if distance_matrix.size == 0:
            return None

        sequence_names = [record.id for record in alignment]
        tree = self._build_nj_tree_from_distances(distance_matrix, sequence_names)
        if not tree:
            # Report the failure instead of announcing success.
            print("❌ Could not build a tree from the distance matrix")
            return None

        tree = self._optimize_branch_lengths_ml_safe(tree, alignment)
        print("✓ ML tree constructed")
        return tree
    except Exception as e:
        print(f"Error constructing ML tree: {e}")
        return None
def _build_nj_tree_from_distances(self, distance_matrix: np.ndarray, sequence_names: List[str]) -> Optional[Tree]:
    """Run neighbor joining on a square distance matrix.

    The matrix is converted to the lower-triangular list form passed to
    ``DistanceMatrix`` (row ``i`` holds entries ``0..i``), with the
    diagonal forced to 0.0 and negative distances clipped to 0.0.

    Returns:
        The tree when it passes ``_validate_tree_structure``; otherwise
        ``None`` (also on size mismatch or any error).
    """
    try:
        taxa_count = len(sequence_names)
        if distance_matrix.shape[0] != taxa_count:
            print("Error: Distance matrix size mismatch")
            return None

        lower_triangle = []
        for row in range(taxa_count):
            entries = [max(0.0, float(distance_matrix[row][col])) for col in range(row)]
            entries.append(0.0)  # diagonal element
            lower_triangle.append(entries)

        dm = DistanceMatrix(names=sequence_names, matrix=lower_triangle)
        tree = DistanceTreeConstructor().nj(dm)
        if self._validate_tree_structure(tree):
            return tree
        return None
    except Exception as e:
        print(f"Error building NJ tree: {e}")
        return None
def _validate_tree_structure(self, tree: Tree, max_depth: int = 100) -> bool:
    """Check that the tree is acyclic, unshared and not too deep.

    Starting from ``tree.root`` (or ``tree`` itself when it has no
    ``root`` attribute), every node reachable through ``clades`` is
    walked. Validation fails if a node lies deeper than ``max_depth``, if
    the same node object is reached twice, or if the walk raises.
    """
    try:
        start = tree.root if hasattr(tree, 'root') else tree
        seen_ids = set()
        frontier = [(start, 0)]
        while frontier:
            node, depth = frontier.pop()
            if depth > max_depth or id(node) in seen_ids:
                return False
            seen_ids.add(id(node))
            frontier.extend((child, depth + 1) for child in getattr(node, 'clades', []))
        return True
    except Exception:
        return False
def _optimize_branch_lengths_ml_safe(self, tree: Tree, alignment: MultipleSeqAlignment) -> Tree:
    """Adjust every defined branch length of ``tree`` in place.

    Each clade returned by ``_get_clades_safe`` whose ``branch_length`` is
    not ``None`` receives the value from
    ``_calculate_optimal_branch_length``, floored at 0.001.

    The recursion limit is temporarily raised to at least 1000 for the
    traversal and always restored. It is never lowered: forcing it to 1000
    could shrink a higher limit set elsewhere, and ``setrecursionlimit``
    raises if the new limit is below the current call depth.

    Returns:
        The same tree object, modified when possible and untouched when
        the alignment cannot be encoded or an error occurs.
    """
    try:
        print("🔧 Optimizing branch lengths...")
        seq_matrix = self._alignment_to_matrix(alignment)
        if seq_matrix.size == 0:
            return tree

        previous_limit = sys.getrecursionlimit()
        sys.setrecursionlimit(max(previous_limit, 1000))
        try:
            for clade in self._get_clades_safe(tree):
                if getattr(clade, 'branch_length', None) is not None:
                    optimal_length = self._calculate_optimal_branch_length(clade, seq_matrix)
                    clade.branch_length = max(optimal_length, 0.001)
        finally:
            sys.setrecursionlimit(previous_limit)

        print("✓ Branch lengths optimized")
        return tree
    except Exception as e:
        print(f"Warning: Branch optimization failed: {e}")
        return tree
def _get_clades_safe(self, tree: Tree, max_depth: int = 50) -> List:
    """Collect the tree's nodes in pre-order without unbounded descent.

    The walk starts at ``tree.root`` (or ``tree`` itself when it has no
    ``root`` attribute) and follows each node's ``clades``. Nodes deeper
    than ``max_depth`` and nodes already seen are skipped. If the walk
    raises, a warning is printed and the nodes gathered so far are
    returned.
    """
    collected = []
    seen_ids = set()
    try:
        start = tree.root if hasattr(tree, 'root') else tree
        stack = [(start, 0)]
        while stack:
            node, depth = stack.pop()
            if depth > max_depth or id(node) in seen_ids:
                continue
            seen_ids.add(id(node))
            collected.append(node)
            # Push children reversed so they are popped left-to-right.
            children = list(getattr(node, 'clades', []))
            stack.extend((child, depth + 1) for child in reversed(children))
    except Exception as e:
        print(f"Warning: Tree traversal error: {e}")
    return collected
def _calculate_optimal_branch_length(self, clade: object, seq_matrix: np.ndarray) -> float:
    """Return a heuristically rescaled branch length for ``clade``.

    The current length is multiplied by 0.9 when the clade has a truthy
    ``name`` and by 1.1 otherwise, then clamped to ``[0.001, 1.0]``.

    Args:
        clade: Any object exposing ``branch_length`` and optionally
            ``name`` (the previous ``float`` annotation was incorrect).
        seq_matrix: Encoded alignment; accepted but not read here.

    Returns:
        The rescaled length, or the default 0.1 when the current length is
        missing, not finite, not positive, or cannot be converted.
    """
    try:
        raw_length = getattr(clade, 'branch_length', None)
        if raw_length is None:
            return 0.1

        current_length = float(raw_length)
        if not np.isfinite(current_length) or current_length <= 0:
            return 0.1

        scale = 0.9 if getattr(clade, 'name', None) else 1.1
        return min(max(current_length * scale, 0.001), 1.0)
    except Exception:
        return 0.1
def calculate_ml_likelihood_safe(self, tree: Tree, alignment: MultipleSeqAlignment) -> float:
    """Return an approximate log-likelihood score for the alignment.

    Columns are sampled from the first ``min(length, 1000)`` sites with a
    stride of ``max(1, sites // 100)``. Columns with fewer than two
    A/T/G/C codes are skipped; for the rest, the natural log of the
    per-site score from ``_calculate_site_likelihood_safe`` is summed.

    Returns:
        The summed log score (0.0 when no column qualifies), or
        ``-np.inf`` when the alignment cannot be encoded or an error
        occurs.
    """
    try:
        print("Trying to calculate tree likelihood...")
        encoded = self._alignment_to_matrix(alignment)
        if encoded.size == 0:
            return -np.inf

        site_limit = min(encoded.shape[1], 1000)
        stride = max(1, site_limit // 100)
        log_total = 0.0
        for column in range(0, site_limit, stride):
            pattern = encoded[:, column]
            # Need at least two real nucleotides for the site to count.
            if np.sum(pattern < 4) < 2:
                continue
            site_score = self._calculate_site_likelihood_safe(tree, pattern)
            if site_score > 0:
                log_total += np.log(site_score)

        print(f"Likelihood: {log_total:.2f}")
        return log_total
    except Exception as e:
        print(f"Error calculating likelihood: {e}")
        return -np.inf
def _calculate_site_likelihood_safe(self, tree: "Tree", site_pattern: np.ndarray) -> float:
    """Return a diversity-based score for one alignment column.

    With ``u`` distinct A/T/G/C codes among ``n`` such codes in the
    column, the score is ``exp(-(u / 4) * n * 0.1)``, floored at 1e-10.

    Args:
        tree: The tree being scored; accepted but not read here (the
            previous ``np.ndarray`` annotation was incorrect).
        site_pattern: Integer codes of one column (see
            ``_alignment_to_matrix``); codes of 4 and above are ignored.

    Returns:
        A plain ``float``: 1.0 when the column has no A/T/G/C code, the
        floored score otherwise, or 1e-10 when the computation raises.
    """
    try:
        observed = site_pattern[site_pattern < 4]
        if len(observed) == 0:
            return 1.0

        diversity_factor = len(np.unique(observed)) / 4.0
        score = float(np.exp(-diversity_factor * len(observed) * 0.1))
        return max(score, 1e-10)
    except Exception:
        return 1e-10
| # --- Reporting --- | |
def generate_detailed_report(self, matched_ids: List[str], actual_percentage: float) -> bool:
    """Write a standalone HTML report for the current query.

    Sections, in order: query information and metadata, matched
    sequences, model accuracies, tree insights, and the summary stored
    under ``self.tree_structure['ml_analysis']``. The file is written
    UTF-8 encoded to the current working directory as
    ``detailed_report_<query id with '/' replaced by '_'>.html``.

    All dataset- and user-supplied text is HTML-escaped before insertion.

    Args:
        matched_ids: Accession numbers to list as matches. Ids missing
            from ``self.data`` are listed with ``N/A`` metadata.
        actual_percentage: Achieved similarity, shown with one decimal.

    Returns:
        True when the file was written; False when any step raised (the
        error and its traceback are printed).
    """
    import html  # function-scope import: not among the module-level imports

    page_head = """<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Phylogenetic Analysis Report - {query_id}</title>
<style>
body {{
    font-family: Arial, sans-serif;
    margin: 20px;
    background-color: #f9f9f9;
    color: #333;
}}
h1 {{
    text-align: center;
    color: #2c3e50;
}}
h2 {{
    color: #34495e;
    margin-top: 20px;
}}
table {{
    width: 100%;
    border-collapse: collapse;
    margin-bottom: 20px;
    background-color: #fff;
    box-shadow: 0 2px 5px rgba(0,0,0,0.1);
}}
th, td {{
    padding: 10px;
    text-align: left;
    border: 1px solid #ddd;
}}
th {{
    background-color: #3498db;
    color: #fff;
}}
tr:nth-child(even) {{
    background-color: #f2f2f2;
}}
tr:hover {{
    background-color: #e0f7fa;
}}
.metadata {{
    margin-left: 20px;
    font-size: 0.9em;
}}
.metadata p {{
    margin: 5px 0;
}}
@media (max-width: 600px) {{
    table {{
        font-size: 0.85em;
    }}
    th, td {{
        padding: 8px;
    }}
}}
</style>
</head>
<body>
<h1>Phylogenetic Analysis Report</h1>
<p style="text-align: center;">Generated on: {timestamp}</p>
<p style="text-align: center;">Query ID: {query_id}</p>
"""

    def esc(value) -> str:
        """HTML-escape any value so data text cannot break the markup."""
        return html.escape(str(value))

    def render_table(headers, rows) -> str:
        """Render a header row plus data rows as one HTML table."""
        header_cells = "".join(f"<th>{esc(header)}</th>" for header in headers)
        lines = ["<table>", f"<tr>{header_cells}</tr>"]
        for row in rows:
            lines.append("<tr>" + "".join(f"<td>{esc(cell)}</td>" for cell in row) + "</tr>")
        lines.append("</table>")
        return "\n".join(lines) + "\n"

    def shorten(sequence, limit: int = 50) -> str:
        """Return the sequence as text, cut to ``limit`` chars plus '...'."""
        text = str(sequence)
        return text[:limit] + "..." if len(text) > limit else text

    def count_nodes(node) -> int:
        """Count ``node`` and every descendant reachable via 'children'."""
        return 1 + sum(count_nodes(child) for child in node.get('children', {}).values())

    def find_query_path(node, ancestors):
        """Return the first pre-order 'a -> b -> query' path, or None."""
        trail = ancestors + [str(node['name'])]
        if node.get('data', {}).get('is_query', False):
            return " -> ".join(trail)
        for child in node.get('children', {}).values():
            found = find_query_path(child, trail)
            if found:
                return found
        return None

    def as_percent(accuracy) -> str:
        """Format an accuracy; only ``None`` means the model is untrained."""
        return f"{accuracy:.2%}" if accuracy is not None else "Not trained"

    try:
        print("📝 Generating detailed HTML analysis report...")
        timestamp = time.strftime("%Y-%m-%d %H:%M:%S %Z")
        sections = [page_head.format(query_id=esc(self.query_id), timestamp=esc(timestamp))]

        # --- Query Information ---
        # Classify the query and fetch its row with the SAME test, so a
        # dataset query can never end up without a row.
        query_row = None
        query_type = "Novel Sequence"
        accession_hits = self.data[self.data['Accession Number'] == self.query_id]
        if not accession_hits.empty:
            query_type, query_row = "Accession Number", accession_hits.iloc[0]
        else:
            sequence_hits = self.data[self.data['F-gene'] == self.query_sequence]
            if not sequence_hits.empty:
                query_type, query_row = "Dataset Sequence", sequence_hits.iloc[0]

        if query_row is None:
            query_ml = self.predict_ml_group(self.query_sequence)
            query_genotype = self.predict_genotype(self.query_sequence)
            query_metadata = {"F-gene": shorten(self.query_sequence)}
        else:
            query_ml = query_row['ML']
            query_genotype = query_row['Genotype']
            query_metadata = query_row.to_dict()
            query_metadata['F-gene'] = shorten(query_metadata['F-gene'])

        sections.append("<h2>Query Information</h2>\n")
        sections.append(render_table(["Field", "Value"], [
            ["Query ID", self.query_id],
            ["Query Type", query_type],
            ["Sequence Length", f"{len(self.query_sequence)} nucleotides"],
            ["ML Group", query_ml],
            ["Genotype", query_genotype],
            ["Target Similarity", f"{self.matching_percentage}%"],
            ["Actual Similarity", f"{actual_percentage:.1f}%"],
        ]))
        sections.append('<div class="metadata">\n<h3>Metadata</h3>\n')
        sections.extend(
            f"<p><strong>{esc(key)}:</strong> {esc(value)}</p>\n"
            for key, value in query_metadata.items()
        )
        sections.append("</div>\n")

        # --- Matched Sequences ---
        matched_rows = []
        for seq_id in matched_ids:
            hits = self.data[self.data['Accession Number'] == seq_id]
            # An id missing from the dataset is still listed, with N/A fields.
            record = hits.iloc[0] if not hits.empty else {}
            matched_rows.append([
                seq_id,
                f"{self.similarity_scores.get(seq_id, 0.0):.1f}",
                record.get('ML', 'N/A'),
                record.get('Genotype', 'N/A'),
                record.get('Host', 'N/A'),
                record.get('Country', 'N/A'),
                record.get('Isolate', 'N/A'),
                record.get('Year', 'N/A'),
            ])

        sections.append("<h2>Matched Sequences</h2>\n")
        sections.append(f"<p>Total Matched Sequences: {len(matched_ids)}</p>\n")
        if matched_rows:
            sections.append(render_table(
                ["Accession Number", "Similarity (%)", "ML Group", "Genotype",
                 "Host", "Country", "Isolate", "Year"],
                matched_rows,
            ))
        else:
            sections.append("<p>No matched sequences found.</p>\n")

        # --- Model Performance ---
        sections.append("<h2>Model Performance</h2>\n")
        sections.append(render_table(["Metric", "Value"], [
            ["ML Model Accuracy", as_percent(self.ml_model_accuracy)],
            ["Genotype Model Accuracy", as_percent(self.genotype_model_accuracy)],
        ]))

        # --- Phylogenetic Tree Insights ---
        # Count from the root node: the top-level dict only holds 'root'
        # (and 'ml_analysis'), so counting it would always give 1.
        root = self.tree_structure['root']
        sections.append("<h2>Phylogenetic Tree Insights</h2>\n")
        sections.append(render_table(["Field", "Value"], [
            ["Total Nodes", count_nodes(root)],
            ["ML Groups Represented", len(root.get('children', {}))],
            ["Query Node Path", find_query_path(root, []) or "Not found"],
        ]))

        # --- ML Analysis Results ---
        ml_analysis = self.tree_structure.get('ml_analysis', {})
        try:
            log_likelihood_text = f"{float(ml_analysis.get('log_likelihood')):.2f}"
        except (TypeError, ValueError):
            log_likelihood_text = "N/A"
        sections.append("<h2>Maximum Likelihood Analysis Results</h2>\n")
        sections.append(render_table(["Field", "Value"], [
            ["ML Tree Available", ml_analysis.get('ml_tree_available', False)],
            ["Log-Likelihood", log_likelihood_text],
            ["Sequence Count", ml_analysis.get('sequence_count', 'N/A')],
            ["Alignment Length", ml_analysis.get('alignment_length', 'N/A')],
        ]))

        sections.append("</body>\n</html>\n")

        # --- Save HTML Report ---
        report_filename = f"detailed_report_{self.query_id.replace('/', '_')}.html"
        print(f"Attempting to save report to: {os.path.abspath(report_filename)}")
        # The page declares UTF-8, so write it as UTF-8 on every platform.
        with open(report_filename, 'w', encoding='utf-8') as f:
            f.write("".join(sections))
        print(f"✓ Detailed HTML report saved as '{report_filename}'")
        return True
    except Exception as e:
        print(f"Error generating detailed report: {str(e)}")
        import traceback
        traceback.print_exc()
        return False
def command_line_interface():
    """Parse command-line arguments and run the analysis non-interactively.

    For every query (the comma-separated ``--batch`` list when given,
    otherwise ``--query``) this finds similar sequences, builds the tree
    structure, saves the interactive tree as HTML and generates the
    detailed report. ``--query`` is required by the parser even when
    ``--batch`` is used; ``--batch`` takes precedence.

    Output files are written to ``--output-dir``. Because the report
    writer saves relative to the working directory, the process changes
    into that directory (after loading data and training) rather than
    prefixing individual file names.

    With ``--save-json`` a per-query summary (ids, similarities, matched
    ids and their scores) is saved as ``analysis_results_<query>.json``.

    Exits with status 1 on invalid arguments, a missing data file, a data
    loading failure, or an unusable output directory.
    """
    import json  # function-scope import: only needed for --save-json

    parser = argparse.ArgumentParser(
        description="Advanced Phylogenetic Tree Analyzer with AI-enhanced similarity matching",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="Examples:\n %(prog)s -d data.csv -q MH087032 -s 95\n %(prog)s -d data.csv -q MH087032 -s 90 --no-ai --batch query1,query2,query3"
    )
    parser.add_argument('-d', '--data', required=True, help='Path to CSV data file')
    parser.add_argument('-q', '--query', required=True, help='Query sequence ID or nucleotide sequence')
    parser.add_argument('-s', '--similarity', type=float, default=95.0, help='Target similarity percentage (70-99, default: 95)')
    parser.add_argument('--no-ai', action='store_true', help='Skip AI model training')
    parser.add_argument('--batch', help='Comma-separated list of query IDs for batch processing')
    parser.add_argument('--output-dir', default='.', help='Output directory for results')
    parser.add_argument('--save-json', action='store_true', help='Save detailed results to JSON')
    args = parser.parse_args()

    # Validate arguments
    if not 70 <= args.similarity <= 99:
        print("❌ Similarity percentage must be between 70 and 99.")
        sys.exit(1)
    if not Path(args.data).exists():
        print(f"❌ Data file not found: {args.data}")
        sys.exit(1)

    # Initialize analyzer
    analyzer = PhylogeneticTreeAnalyzer()
    if not analyzer.load_data(args.data):
        print("❌ Failed to load data.")
        sys.exit(1)

    # Train AI model unless disabled
    if not args.no_ai:
        print("⏳ Training AI model...")
        start_time = time.time()
        if analyzer.train_ai_model():
            print(f"✅ AI model training completed in {time.time() - start_time:.1f} seconds")
        else:
            print("⚠️ AI model training failed, continuing with basic analysis")

    # Honour --output-dir for every file written below.
    try:
        output_dir = Path(args.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        os.chdir(output_dir)
    except OSError as e:
        print(f"❌ Cannot use output directory {args.output_dir}: {e}")
        sys.exit(1)

    # The report shows this value as the target similarity.
    analyzer.matching_percentage = args.similarity

    # Process queries
    raw_queries = args.batch.split(',') if args.batch else [args.query]
    for query in (entry.strip() for entry in raw_queries):
        if not query:
            continue  # tolerate stray commas in --batch
        print(f"🔍 Processing: {query}")
        if not analyzer.find_query_sequence(query):
            print(f"❌ Query not found: {query}")
            continue
        matched_ids, actual_percentage = analyzer.find_similar_sequences(args.similarity)
        if not matched_ids:
            print(f"❌ No similar sequences found for {query}")
            continue
        analyzer.matched_sequences = matched_ids
        analyzer.actual_percentage = actual_percentage

        if not analyzer.build_tree_structure_with_ml_safe(matched_ids):
            print(f"❌ Failed to build tree structure for {query}")
            continue

        safe_query = query.replace('/', '_')
        fig = analyzer.create_interactive_tree(matched_ids, actual_percentage)
        if fig:
            html_filename = f"phylogenetic_tree_{safe_query}_interactive.html"
            fig.write_html(html_filename)
            print(f"📄 Interactive HTML saved: {html_filename}")

        # Only announce the report when it was written; its file name is
        # derived from the analyzer's query id, not from the raw query.
        if analyzer.generate_detailed_report(matched_ids, actual_percentage):
            report_name = f"detailed_report_{str(analyzer.query_id).replace('/', '_')}.html"
            print(f"📄 Detailed HTML report saved: {report_name}")
        else:
            print(f"⚠️ Detailed report could not be generated for {query}")

        if args.save_json:
            try:
                summary = {
                    'query': query,
                    'query_id': str(analyzer.query_id),
                    'target_similarity': args.similarity,
                    'actual_similarity': float(actual_percentage),
                    'matched_ids': [str(seq_id) for seq_id in matched_ids],
                    'similarity_scores': {
                        str(seq_id): float(analyzer.similarity_scores.get(seq_id, 0.0))
                        for seq_id in matched_ids
                    },
                }
                json_filename = f"analysis_results_{safe_query}.json"
                with open(json_filename, 'w', encoding='utf-8') as handle:
                    json.dump(summary, handle, indent=2, ensure_ascii=False)
                print(f"📄 JSON results saved: {json_filename}")
            except Exception as e:
                # Best effort: a summary failure must not abort the batch.
                print(f"⚠️ Could not save JSON results for {query}: {e}")

        print(f"✅ Analysis completed for {query}")
def main():
    """Run the full analysis interactively, prompting on the console.

    Steps: load the dataset (asking for a path until one loads), train the
    AI model, read the query and the target similarity, find similar
    sequences, build the tree structure, save the interactive tree as
    ``phylogenetic_tree_interactive.html`` and generate the detailed
    report. The function returns early, after printing a message, when
    the user cancels or a step fails.
    """
    banner = "=" * 70
    print("\n" + banner)
    print("🧬 PHYLOGENETIC TREE ANALYZER - ADVANCED ML-BASED ANALYSIS")
    print("Version 2.0 | AI-Enhanced Similarity Matching")
    print(banner)
    analyzer = PhylogeneticTreeAnalyzer()

    # Load data: keep asking until a path both exists and loads.
    data_file = "f cleaned.csv"
    while not (Path(data_file).exists() and analyzer.load_data(data_file)):
        print(f"❌ File not found or invalid: {data_file}")
        data_file = input("Enter valid data file path: ").strip()
        if not data_file:
            print("❌ Analysis cancelled.")
            return

    # Train AI model
    print("⏳ Training AI model...")
    started = time.time()
    if analyzer.train_ai_model():
        print(f"✅ AI model training completed in {time.time() - started:.1f} seconds")
    else:
        print("⚠️ AI model training failed, continuing with basic analysis")

    # Get query sequence
    while not analyzer.find_query_sequence(
        input("\nEnter query sequence or ID (min 10 nucleotides): ").strip()
    ):
        if input("❌ Invalid input. Try again? (y/n): ").strip().lower() != 'y':
            print("👋 Analysis cancelled.")
            return

    # Set similarity percentage (empty input selects the default of 85)
    target_percentage = None
    while target_percentage is None:
        answer = input("Enter target similarity percentage (1-99) [85]: ").strip()
        try:
            candidate = float(answer) if answer else 85.0
        except ValueError:
            print("❌ Please enter a valid number.")
            continue
        if 1 <= candidate <= 99:
            target_percentage = candidate
        else:
            print("❌ Please enter a percentage between 1 and 99.")
    analyzer.matching_percentage = target_percentage

    # Find similar sequences
    print(f"⏳ Analyzing sequences for {target_percentage}% similarity...")
    started = time.time()
    matched_ids, actual_percentage = analyzer.find_similar_sequences(target_percentage)
    if not matched_ids:
        print(f"❌ No similar sequences found at {target_percentage}% similarity.")
        return
    analyzer.matched_sequences = matched_ids
    analyzer.actual_percentage = actual_percentage
    print(f"✅ Similarity analysis completed in {time.time() - started:.1f} seconds")

    # Build tree structure
    print("⏳ Building phylogenetic tree structure...")
    started = time.time()
    if not analyzer.build_tree_structure_with_ml_safe(matched_ids):
        print("❌ Failed to build tree structure.")
        return
    print(f"✅ Tree structure built in {time.time() - started:.1f} seconds")

    # Create visualization and save HTML
    print("⏳ Creating interactive visualization...")
    started = time.time()
    fig = analyzer.create_interactive_tree(matched_ids, actual_percentage)
    if not fig:
        print("❌ Visualization creation failed.")
        return
    html_filename = "phylogenetic_tree_interactive.html"
    fig.write_html(html_filename)
    print(f"📄 Interactive HTML saved: {html_filename}")

    # Generate detailed report
    print("⏳ Generating detailed report...")
    started = time.time()
    if analyzer.generate_detailed_report(matched_ids, actual_percentage):
        print(f"✅ Detailed report generated in {time.time() - started:.1f} seconds")

    summary_lines = [
        "\n🎉 Analysis completed successfully!",
        f" Query ID: {analyzer.query_id}",
        f" Query sequence length: {len(analyzer.query_sequence)} nucleotides",
        f" Similar sequences found: {len(matched_ids)}",
        f" Actual similarity percentage: {actual_percentage:.1f}%",
        f" HTML visualization file: {html_filename}",
        f" HTML report file: detailed_report_{analyzer.query_id.replace('/', '_')}.html",
    ]
    for line in summary_lines:
        print(line)
if __name__ == "__main__":
    try:
        # Any command-line argument selects the argparse-driven mode, which
        # was otherwise unreachable from this entry point; with no
        # arguments the interactive prompts run as before.
        if len(sys.argv) > 1:
            command_line_interface()
        else:
            main()
    except KeyboardInterrupt:
        print("\n👋 Goodbye!")
        sys.exit(0)
    except Exception as e:
        print(f"\n❌ Unexpected error: {e}")
        sys.exit(1)