# -*- coding: utf-8 -*-
"""ML simplified tree.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1LiDjip-h70ilIex9PedpWCZARWglija7
"""

# Commented out IPython magic to ensure Python compatibility.
import pandas as pd
import numpy as np
import plotly.graph_objects as go
import plotly.offline as pyo
from plotly.subplots import make_subplots
from Bio import Phylo, SeqIO, AlignIO
from Bio.Phylo.TreeConstruction import DistanceCalculator, DistanceTreeConstructor
from Bio.Align import MultipleSeqAlignment
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import warnings
import os
import sys
from typing import Dict, List, Tuple, Optional, Any
import json
import re
from scipy.optimize import minimize
from scipy.spatial.distance import pdist, squareform
from Bio.Phylo import BaseTree
import itertools
from collections import Counter, defaultdict, deque
import argparse
import time
from pathlib import Path

warnings.filterwarnings('ignore')


class PhylogeneticTreeAnalyzer:
    """Builds and visualizes a simplified 'ML group' phylogenetic tree.

    Workflow: load a CSV of sequences, locate a query (by accession number
    or raw nucleotide string), find dataset sequences at a target similarity,
    then lay the matches out as a horizontal tree keyed by ML group/genotype.
    """

    def __init__(self):
        self.data = None                      # pandas DataFrame loaded from CSV
        self.query_sequence = None            # cleaned nucleotide string of the query
        self.query_id = None                  # accession number or generated QUERY_xxxxx id
        self.matching_percentage = 95.0       # requested similarity target (default)
        self.actual_percentage = None         # similarity actually achieved
        self.matched_sequences = []
        self.tree_structure = {}              # nested dict built by build_tree_structure()
        self.similarity_scores = {}           # id -> similarity % for matched sequences
        self.ai_model = None                  # RandomForestClassifier once trained
        self.label_encoder = LabelEncoder()

        # ML-specific attributes
        self.ml_tree = None
        self.ml_alignment = None
        self.ml_results = {}
        self.horizontal_line_tracker = []     # Track horizontal lines with verticals
        self.query_ml_group = None            # Track which ML group contains the query
        self.base_horizontal_length = 1.2     # Base length for horizontal lines

    def load_data(self, data_file: str):
        """Load the sequence CSV into self.data.

        Returns True on success, False on any error. The CSV is expected to
        carry at least 'ML', 'Genotype', 'Accession Number' and 'F-gene'
        columns (schema check below is intentionally disabled).
        """
        try:
            self.data = pd.read_csv(data_file)
            # Optional schema validation, currently disabled.
            # NOTE(review): the original commented-out check computed columns
            # *present but not required* (extra columns); the version below
            # actually finds missing required columns.
            # required_columns = ['Accession Number', 'ML', 'Genotype', 'Host',
            #                     'Country', 'Isolate', 'Year', 'F-gene']
            # missing_columns = [col for col in required_columns
            #                    if col not in self.data.columns]
            # if missing_columns:
            #     print(f"Error: Missing required columns: {missing_columns}")
            #     return False
            print(f"✓ Data loaded successfully: {len(self.data)} sequences")
            print(f"✓ ML Groups found: {self.data['ML'].nunique()}")
            print(f"✓ Genotypes found: {self.data['Genotype'].nunique()}")
            return True
        except Exception as e:
            print(f"Error loading data: {e}")
            return False

    def calculate_f_gene_similarity(self, seq1: str, seq2: str) -> float:
        """Jaccard similarity (in %) between the 5-mer sets of two sequences.

        Non-ATGC characters are stripped before comparison. Returns 0.0 for
        empty/invalid input, 100.0 when both sequences are shorter than k
        (both k-mer sets empty), otherwise a percentage rounded to 2 decimals.
        """
        try:
            # Handle empty or None sequences
            if not seq1 or not seq2:
                return 0.0

            # Normalize: uppercase, nucleotides only
            seq1 = re.sub(r'[^ATGC]', '', str(seq1).upper())
            seq2 = re.sub(r'[^ATGC]', '', str(seq2).upper())
            if len(seq1) == 0 or len(seq2) == 0:
                return 0.0

            # k-mer (5-mer) decomposition; range() already guarantees full-length
            # slices, so no per-slice length check is needed.
            k = 5
            kmers1 = {seq1[i:i + k] for i in range(len(seq1) - k + 1)}
            kmers2 = {seq2[i:i + k] for i in range(len(seq2) - k + 1)}

            if not kmers1 and not kmers2:
                return 100.0
            if not kmers1 or not kmers2:
                return 0.0

            # Jaccard similarity = |intersection| / |union|
            intersection = len(kmers1 & kmers2)
            union = len(kmers1 | kmers2)
            similarity = (intersection / union) * 100 if union > 0 else 0.0
            return round(similarity, 2)
        except Exception as e:
            print(f"Error calculating similarity: {e}")
            return 0.0

    def train_ai_model(self):
        """Train a RandomForest on fixed 3-mer/4-mer count features of F-gene.

        Features: counts of the first 50 lexicographic 3-mers plus the first
        50 lexicographic 4-mers (100 features). Target: encoded 'ML' column.
        Returns True on success, False if data is insufficient or on error.
        """
        try:
            # Skip training if insufficient data
            if len(self.data) < 10:  # Require minimum 10 samples
                print("⚠️ Insufficient data to train AI model (min 10 samples required)", flush=True)
                return False

            print("🤖 Training AI model for sequence analysis...", flush=True)
            f_gene_sequences = self.data['F-gene'].fillna('').astype(str)

            # Fixed k-mer vocabularies, hoisted out of the per-sequence loop.
            # (Original called __import__('itertools') although itertools is
            # imported at the top of the file.)
            all_3mers = [''.join(p) for p in itertools.product('ATGC', repeat=3)]
            all_4mers = [''.join(p) for p in itertools.product('ATGC', repeat=4)]

            features = []
            for seq in f_gene_sequences:
                seq_clean = re.sub(r'[^ATGC]', '', seq.upper())
                if len(seq_clean) < 3:
                    features.append([0] * 100)  # Placeholder for short sequences
                    continue
                # Counter gives O(n) counting (the original list.count() loop
                # was accidentally O(n^2) per sequence).
                kmer_counts_3 = Counter(seq_clean[i:i + 3] for i in range(len(seq_clean) - 2))
                kmer_counts_4 = Counter(seq_clean[i:i + 4] for i in range(len(seq_clean) - 3))
                feature_vector = [kmer_counts_3.get(kmer, 0) for kmer in all_3mers[:50]]
                feature_vector.extend(kmer_counts_4.get(kmer, 0) for kmer in all_4mers[:50])
                features.append(feature_vector)

            # Prepare target labels (ML groups)
            targets = self.label_encoder.fit_transform(self.data['ML'].fillna('Unknown'))
            if len(np.unique(targets)) < 2:
                print("⚠️ Need at least 2 distinct classes for training", flush=True)
                return False

            # Train Random Forest model
            X = np.array(features)
            y = targets
            X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
            self.ai_model = RandomForestClassifier(n_estimators=100, random_state=42)
            self.ai_model.fit(X_train, y_train)

            accuracy = self.ai_model.score(X_test, y_test)
            print(f"✓ AI model trained successfully with accuracy: {accuracy:.2%}", flush=True)
            return True
        except Exception as e:
            print(f"🚨 CRITICAL training error: {e}", flush=True)
            import traceback
            traceback.print_exc()
            return False

    def find_query_sequence(self, query_input: str) -> bool:
        """Resolve the query from an accession number or raw sequence.

        Modified to accept any sequence input, not just those existing in
        the dataset. Sets self.query_id / self.query_sequence and returns
        True when a usable query was found.
        """
        try:
            # Check if input is an accession number from the dataset
            if query_input in self.data['Accession Number'].values:
                self.query_id = query_input
                query_row = self.data[self.data['Accession Number'] == query_input].iloc[0]
                self.query_sequence = query_row['F-gene']
                print(f"✓ Query sequence found by ID: {query_input}")
                return True

            # Otherwise treat input as a nucleotide sequence
            query_clean = re.sub(r'[^ATGC]', '', str(query_input).upper())

            # Accept any sequence with reasonable length (even short ones for testing)
            if len(query_clean) >= 10:  # Minimum sequence length (reduced from 50)
                if query_input not in self.data['Accession Number'].values:
                    # Generate a unique query ID for novel sequences.
                    # NOTE(review): hash() is salted per interpreter process,
                    # so this ID is not reproducible across runs — confirm
                    # whether stable IDs are needed downstream.
                    self.query_id = f"QUERY_{hash(query_clean) % 100000:05d}"
                    self.query_sequence = query_clean
                    print(f"✓ Novel query sequence accepted with ID: {self.query_id}")
                    print(f"  Sequence length: {len(query_clean)} nucleotides")
                    return True
                else:
                    # If somehow it matches an accession but wasn't caught above
                    self.query_id = query_input
                    self.query_sequence = query_clean
                    print(f"✓ Query sequence processed: {query_input}")
                    return True

            # Sequence too short or not nucleotide-like
            if len(query_clean) < 10:
                print(f"✗ Query sequence too short. Minimum length: 10 nucleotides (provided: {len(query_clean)})")
            else:
                print(f"✗ Invalid sequence format. Please provide nucleotides (A, T, G, C) or valid accession number")
            return False
        except Exception as e:
            print(f"Error processing query sequence: {e}")
            return False

    def find_similar_sequences(self, target_percentage: float) -> Tuple[List[str], float]:
        """Find dataset sequences near `target_percentage` similarity to the query.

        Modified to work with any query sequence, including novel ones not in
        the dataset. Returns (matched_ids, actual_percentage); also populates
        self.similarity_scores for the matched ids.
        """
        try:
            print(f"🔍 Finding sequences with {target_percentage}% similarity to query...")
            similarities = []

            # Score the query against every dataset sequence
            for idx, row in self.data.iterrows():
                # Skip the query itself (only relevant for existing accession numbers)
                if hasattr(self, 'query_id') and row['Accession Number'] == self.query_id:
                    continue
                try:
                    similarity = self.calculate_f_gene_similarity(self.query_sequence, row['F-gene'])
                    similarities.append({
                        'id': row['Accession Number'],
                        'similarity': similarity,
                        'ml': row['ML'] if 'ML' in row else 'Unknown',
                        'genotype': row['Genotype'] if 'Genotype' in row else 'Unknown'
                    })
                except Exception as seq_error:
                    print(f"⚠ Skipping sequence {row['Accession Number']}: {seq_error}")
                    continue

            if not similarities:
                print("❌ No valid sequences found for comparison")
                return [], target_percentage

            # Sort by similarity (highest first)
            similarities.sort(key=lambda x: x['similarity'], reverse=True)

            # Sequences within ±2% of the requested target
            target_range = 2.0
            candidates = [s for s in similarities
                          if abs(s['similarity'] - target_percentage) <= target_range]

            if not candidates:
                # No exact matches: fall back to the closest similarity and
                # take everything within ±1% of it.
                closest_sim = min(similarities, key=lambda x: abs(x['similarity'] - target_percentage))
                actual_percentage = closest_sim['similarity']
                candidates = [s for s in similarities
                              if abs(s['similarity'] - actual_percentage) <= 1.0]
                print(f"⚠ No sequences found at exactly {target_percentage}%. Using closest: {actual_percentage:.1f}%")
            else:
                actual_percentage = target_percentage

            # Cap results to keep the visualization readable
            max_results = 50  # Adjust as needed
            if len(candidates) > max_results:
                candidates = candidates[:max_results]
                print(f"⚠ Limited results to top {max_results} matches for better visualization")

            # Store similarity scores for later use
            self.similarity_scores = {}  # Reset similarity scores
            for candidate in candidates:
                self.similarity_scores[candidate['id']] = candidate['similarity']
            matched_ids = [c['id'] for c in candidates]

            # Show some statistics
            if similarities:
                max_sim = max(similarities, key=lambda x: x['similarity'])['similarity']
                min_sim = min(similarities, key=lambda x: x['similarity'])['similarity']
                avg_sim = sum(s['similarity'] for s in similarities) / len(similarities)
                print(f"✓ Found {len(matched_ids)} sequences at ~{actual_percentage:.1f}% similarity")
                print(f"  Similarity range in dataset: {min_sim:.1f}% - {max_sim:.1f}% (avg: {avg_sim:.1f}%)")
            return matched_ids, actual_percentage
        except Exception as e:
            print(f"Error finding similar sequences: {e}")
            return [], target_percentage

    def build_tree_structure(self, matched_ids: List[str]) -> Dict:
        """Build the nested tree dict: root -> ML groups -> genotypes -> sequences."""
        try:
            print("🌳 Building normalized horizontal tree structure...")
            tree_structure = {
                'root': {
                    'name': 'Root',
                    'type': 'root',
                    'children': {},
                    'x': 0,
                    'y': 0,
                    'has_vertical_attachment': False,
                    'extension_level': 0
                }
            }

            # Group every dataset row by ML group, then genotype
            ml_groups = {}
            for idx, row in self.data.iterrows():
                ml_group = row['ML']
                genotype = row['Genotype']
                seq_id = row['Accession Number']
                ml_groups.setdefault(ml_group, {}).setdefault(genotype, []).append({
                    'id': seq_id,
                    'data': row.to_dict(),
                    'is_query': seq_id == self.query_id,
                    'is_matched': seq_id in matched_ids,
                    'similarity': self.similarity_scores.get(seq_id, 0.0)
                })

            # Normalize ML group names and organize
            normalized_ml_groups = self._normalize_ml_groups(ml_groups)

            # Build normalized ML level - horizontal layout with progressive extensions
            self._build_normalized_ml_nodes(tree_structure, normalized_ml_groups, matched_ids)

            self.tree_structure = tree_structure
            print(f"✓ Normalized horizontal tree structure built")
            return tree_structure
        except Exception as e:
            print(f"Error building tree structure: {e}")
            return {}

    def _normalize_ml_groups(self, ml_groups: Dict) -> Dict:
        """Normalize ML group names and organize hierarchically.

        'UNCL*' collapses to 'UNCL'; dotted names like 'XII.1.2' collapse to
        their base ('XII'). Groups containing query/matched sequences keep
        their full sub-groups; others contribute at most 2 representatives.
        """
        try:
            normalized_groups = {}
            for ml_name, genotypes in ml_groups.items():
                # Extract base ML name
                if ml_name.startswith('UNCL'):
                    base_ml = 'UNCL'
                elif '.' in ml_name and any(char.isdigit() for char in ml_name):
                    # For names like XII.1.2, XII.1, etc., extract the base (XII)
                    base_ml = ml_name.split('.')[0]
                else:
                    base_ml = ml_name

                if base_ml not in normalized_groups:
                    normalized_groups[base_ml] = {
                        'full_ml_groups': {},
                        'representative_sequences': [],
                        'has_special_sequences': False
                    }

                # Does this ML group carry the query or any matched sequence?
                has_special = any(
                    any(seq['is_query'] or seq['is_matched'] for seq in sequences)
                    for sequences in genotypes.values()
                )
                if has_special:
                    normalized_groups[base_ml]['has_special_sequences'] = True
                    normalized_groups[base_ml]['full_ml_groups'][ml_name] = genotypes
                else:
                    # Add as representative (limit to 2 representatives)
                    if len(normalized_groups[base_ml]['representative_sequences']) < 2:
                        for genotype, sequences in list(genotypes.items())[:2]:
                            if len(normalized_groups[base_ml]['representative_sequences']) < 2:
                                normalized_groups[base_ml]['representative_sequences'].extend(sequences[:1])
            return normalized_groups
        except Exception as e:
            print(f"Error normalizing ML groups: {e}")
            return {}

    def _build_normalized_ml_nodes(self, tree_structure: Dict,
                                   normalized_ml_groups: Dict,
                                   matched_ids: List[str]):
        """Build normalized ML nodes with equal spacing and progressive horizontal extensions."""
        try:
            # Reset horizontal line tracker
            self.horizontal_line_tracker = []

            # Identify which ML group contains the query
            self._identify_query_ml_group(normalized_ml_groups)

            # Calculate equal spacing for all ML groups
            ml_positions = self._calculate_dynamic_ml_positions(normalized_ml_groups)

            # Root gets a vertical connector only with multiple children
            tree_structure['root']['has_vertical_attachment'] = len(normalized_ml_groups) > 1

            for ml_idx, (base_ml, ml_data) in enumerate(normalized_ml_groups.items()):
                y_pos = ml_positions[ml_idx]

                # Vertical attachment only when several full ML groups fan out
                has_vertical = ml_data['has_special_sequences'] and len(ml_data['full_ml_groups']) > 1
                contains_query = (base_ml == self.query_ml_group)

                horizontal_length = self._determine_horizontal_line_length(
                    'normalized_ml_group', has_vertical, contains_query
                )
                x_pos = horizontal_length

                tree_structure['root']['children'][base_ml] = {
                    'name': base_ml,
                    'type': 'normalized_ml_group',
                    'children': {},
                    'x': x_pos,
                    'y': y_pos,
                    'has_special_sequences': ml_data['has_special_sequences'],
                    'has_vertical_attachment': has_vertical,
                    'horizontal_line_length': horizontal_length,
                    'contains_query': contains_query
                }

                if ml_data['has_special_sequences']:
                    # Expand into full ML groups for groups with special sequences
                    self._build_full_ml_nodes(
                        tree_structure['root']['children'][base_ml],
                        ml_data['full_ml_groups'], y_pos, matched_ids, x_pos
                    )
                else:
                    # Attach representative sequences directly
                    self._add_representative_sequences(
                        tree_structure['root']['children'][base_ml],
                        ml_data['representative_sequences'], y_pos, x_pos
                    )
        except Exception as e:
            print(f"Error building normalized ML nodes: {e}")

    def _calculate_dynamic_ml_positions(self, normalized_ml_groups: Dict) -> List[float]:
        """Calculate equal Y positions for all ML groups regardless of content."""
        try:
            ml_count = len(normalized_ml_groups)
            if ml_count == 0:
                return []
            if ml_count == 1:
                return [0.0]
            # Equal spacing (2.0 units) centred on y = 0
            total_spacing = (ml_count - 1) * 2.0
            start_y = -total_spacing / 2
            return [start_y + i * 2.0 for i in range(ml_count)]
        except Exception as e:
            print(f"Error calculating dynamic positions: {e}")
            return list(range(len(normalized_ml_groups)))

    def _build_full_ml_nodes(self, normalized_ml_node: Dict, full_ml_groups: Dict,
                             base_y: float, matched_ids: List[str], parent_x: float):
        """Build full ML nodes with genotypes for groups containing special sequences."""
        try:
            full_ml_positions = self._calculate_full_ml_positions(full_ml_groups, base_y)

            for ml_idx, (full_ml_name, genotypes) in enumerate(full_ml_groups.items()):
                y_pos = full_ml_positions[ml_idx]

                # Vertical attachment only when >1 genotype holds special sequences
                special_genotypes_count = sum(
                    1 for genotype, sequences in genotypes.items()
                    if any(seq['is_query'] or seq['is_matched'] for seq in sequences)
                )
                has_vertical = special_genotypes_count > 1

                contains_query = any(
                    any(seq['is_query'] for seq in sequences)
                    for sequences in genotypes.values()
                )

                horizontal_length = self._determine_horizontal_line_length(
                    'full_ml_group', has_vertical, contains_query
                )
                x_pos = parent_x + horizontal_length

                normalized_ml_node['children'][full_ml_name] = {
                    'name': full_ml_name,
                    'type': 'full_ml_group',
                    'children': {},
                    'x': x_pos,
                    'y': y_pos,
                    'sequences_count': sum(len(seqs) for seqs in genotypes.values()),
                    'has_vertical_attachment': has_vertical,
                    'horizontal_line_length': horizontal_length,
                    'contains_query': contains_query
                }

                self._build_genotype_nodes(
                    normalized_ml_node['children'][full_ml_name],
                    genotypes, y_pos, matched_ids, x_pos
                )
        except Exception as e:
            print(f"Error building full ML nodes: {e}")

    def _calculate_full_ml_positions(self, full_ml_groups: Dict, base_y: float) -> List[float]:
        """Calculate equal positions for full ML groups."""
        try:
            ml_count = len(full_ml_groups)
            if ml_count <= 1:
                return [base_y]
            spacing = 1.5  # Fixed spacing between full ML groups
            start_y = base_y - (spacing * (ml_count - 1)) / 2
            return [start_y + i * spacing for i in range(ml_count)]
        except Exception as e:
            print(f"Error calculating full ML positions: {e}")
            return [base_y] * len(full_ml_groups)

    def _build_genotype_nodes(self, full_ml_node: Dict, genotypes: Dict,
                              base_y: float, matched_ids: List[str], parent_x: float):
        """Build genotype nodes with sequences - horizontal line length based on sequence count."""
        try:
            # Keep only genotypes that contain the query or a matched sequence
            special_genotypes = [
                (genotype, sequences) for genotype, sequences in genotypes.items()
                if any(seq['is_query'] or seq['is_matched'] for seq in sequences)
            ]
            if not special_genotypes:
                return

            genotype_positions = self._calculate_genotype_positions(special_genotypes, base_y)

            # Per-genotype count of special sequences drives line length
            genotype_sequence_counts = []
            for genotype, sequences in special_genotypes:
                special_sequences = [seq for seq in sequences if seq['is_query'] or seq['is_matched']]
                genotype_sequence_counts.append((genotype, sequences, len(special_sequences)))

            for gt_idx, (genotype, sequences, sequence_count) in enumerate(genotype_sequence_counts):
                y_pos = genotype_positions[gt_idx]

                special_sequences = [seq for seq in sequences if seq['is_query'] or seq['is_matched']]
                has_vertical = len(special_sequences) > 1
                contains_query = any(seq['is_query'] for seq in sequences)

                horizontal_length = self._determine_genotype_horizontal_line_length(
                    sequence_count, has_vertical, contains_query
                )
                x_pos = parent_x + horizontal_length

                full_ml_node['children'][genotype] = {
                    'name': genotype,
                    'type': 'genotype',
                    'children': {},
                    'x': x_pos,
                    'y': y_pos,
                    'sequences': sequences,
                    'has_vertical_attachment': has_vertical,
                    'horizontal_line_length': horizontal_length,
                    'contains_query': contains_query,
                    'sequence_count': sequence_count  # Store for reference
                }

                self._add_sequences_horizontal(
                    full_ml_node['children'][genotype], sequences, y_pos, x_pos
                )
        except Exception as e:
            print(f"Error building genotype nodes: {e}")

    def _determine_genotype_horizontal_line_length(self, sequence_count: int,
                                                   has_vertical: bool,
                                                   contains_query: bool = False) -> float:
        """Determine horizontal line length for genotype nodes based on sequence count."""
        try:
            base_length = self.base_horizontal_length
            # Genotype containing the query gets an extra 0.5 units
            query_bonus = 0.5 if contains_query else 0.0

            # More sequences => longer horizontal line
            if sequence_count <= 1:
                length_multiplier = 1.0     # single sequence
            elif sequence_count <= 3:
                length_multiplier = 1.6     # 2-3 sequences
            elif sequence_count <= 5:
                length_multiplier = 2.3     # 4-5 sequences
            else:
                length_multiplier = 6.0     # 6+ sequences (capped)

            return base_length * length_multiplier + query_bonus
        except Exception as e:
            print(f"Error determining genotype horizontal line length: {e}")
            return self.base_horizontal_length

    def _calculate_genotype_positions(self, special_genotypes: List, base_y: float) -> List[float]:
        """Calculate equal positions for genotypes."""
        try:
            genotype_count = len(special_genotypes)
            if genotype_count <= 1:
                return [base_y]
            spacing = 1.0  # Fixed spacing between genotypes
            start_y = base_y - (spacing * (genotype_count - 1)) / 2
            return [start_y + i * spacing for i in range(genotype_count)]
        except Exception as e:
            print(f"Error calculating genotype positions: {e}")
            return [base_y] * len(special_genotypes)

    def _add_representative_sequences(self, normalized_ml_node: Dict,
                                      representative_sequences: List[Dict],
                                      base_y: float, parent_x: float):
        """Add representative sequences directly to normalized ML node."""
        try:
            if not representative_sequences:
                return

            # Representatives get a standard (non-similarity-based) length
            has_vertical = len(representative_sequences) > 1
            horizontal_length = self._determine_horizontal_line_length('representative', has_vertical)
            x_pos = parent_x + horizontal_length

            if len(representative_sequences) == 1:
                positions = [base_y]
            else:
                positions = self._calculate_sequence_positions(representative_sequences, base_y)

            for idx, seq in enumerate(representative_sequences):
                normalized_ml_node['children'][f"{seq['id']}_rep"] = {
                    'name': f"{seq['id']} (Rep)",
                    'type': 'representative_sequence',
                    'data': seq,
                    'x': x_pos,
                    'y': positions[idx],
                    'has_vertical_attachment': False,
                    'horizontal_line_length': horizontal_length
                }
        except Exception as e:
            print(f"Error adding representative sequences: {e}")

    def _add_sequences_horizontal(self, genotype_node: Dict, sequences: List[Dict],
                                  base_y: float, parent_x: float):
        """Add sequences horizontally with similarity-based line lengths."""
        try:
            # Query line length is the 100% reference
            query_line_length = 3.0

            # Query first, then matched (query excluded from the matched list)
            query_sequences = [seq for seq in sequences if seq['is_query']]
            matched_sequences = [seq for seq in sequences if seq['is_matched'] and not seq['is_query']]
            all_special_sequences = query_sequences + matched_sequences

            if len(all_special_sequences) == 1:
                positions = [base_y]
            else:
                positions = self._calculate_sequence_positions(all_special_sequences, base_y)

            for seq_idx, sequence in enumerate(all_special_sequences):
                line_length = self._calculate_similarity_based_line_length(sequence, query_line_length)
                x_pos = parent_x + line_length
                genotype_node['children'][sequence['id']] = {
                    'name': f"{sequence['id']}{' (' + str(sequence['similarity']) + '%)' if sequence['is_matched'] else ''}",
                    'type': 'sequence',
                    'data': sequence,
                    'x': x_pos,
                    'y': positions[seq_idx],
                    'has_vertical_attachment': False,
                    'similarity_line_length': line_length
                }
        except Exception as e:
            print(f"Error adding sequences horizontally: {e}")

    def _calculate_similarity_based_line_length(self, sequence: Dict,
                                                query_line_length: float) -> float:
        """Calculate line length based on similarity percentage relative to query."""
        try:
            if sequence['is_query']:
                # Query sequence gets 100% length
                return query_line_length
            elif sequence['is_matched']:
                # Matched sequences: length proportional to similarity,
                # floored at 20% of the query length for visibility.
                similarity = sequence['similarity']
                proportional_length = (similarity / 100.0) * query_line_length
                min_length = query_line_length * 0.2
                return max(proportional_length, min_length)
            else:
                # Other sequences get a standard length (50% of query)
                return query_line_length * 0.5
        except Exception as e:
            print(f"Error calculating similarity-based line length: {e}")
            return query_line_length * 0.5

    def _calculate_sequence_positions(self, sequences: List[Dict], base_y: float) -> List[float]:
        """Calculate equal positions for sequences."""
        try:
            seq_count = len(sequences)
            if seq_count <= 1:
                return [base_y]
            spacing = 0.8  # Fixed spacing between sequences
            start_y = base_y - (spacing * (seq_count - 1)) / 2
            return [start_y + i * spacing for i in range(seq_count)]
        except Exception as e:
            print(f"Error calculating sequence positions: {e}")
            return [base_y] * len(sequences)

    def _determine_horizontal_line_length(self, node_type: str, has_vertical: bool,
                                          contains_query: bool = False) -> float:
        """Determine horizontal line length based on node type and connections."""
        try:
            base_length = self.base_horizontal_length

            # ML group containing the query gets a much longer line
            if contains_query and node_type == 'normalized_ml_group':
                return base_length * 2.5

            if has_vertical:
                # Grow past the longest tracked vertical-bearing line so far
                current_max = max(self.horizontal_line_tracker, default=base_length)
                new_length = current_max + 0.3
                self.horizontal_line_tracker.append(new_length)
                return new_length
            # Direct connection (no vertical): base length
            return base_length
        except Exception as e:
            print(f"Error determining horizontal line length: {e}")
            return self.base_horizontal_length

    def _identify_query_ml_group(self, normalized_ml_groups: Dict):
        """Identify which ML group contains the query sequence.

        (The original file defined this method twice, verbatim; the duplicate
        definition silently shadowed the first and has been removed.)
        """
        try:
            for base_ml, ml_data in normalized_ml_groups.items():
                if ml_data['has_special_sequences']:
                    for full_ml_name, genotypes in ml_data['full_ml_groups'].items():
                        for genotype, sequences in genotypes.items():
                            if any(seq['is_query'] for seq in sequences):
                                self.query_ml_group = base_ml
                                return
        except Exception as e:
            print(f"Error identifying query ML group: {e}")

    def _calculate_sequence_x_position_horizontal(self, sequence: Dict,
                                                  max_similarity: float) -> float:
        """Calculate X position based on similarity percentage for horizontal layout.

        Superseded by _calculate_similarity_based_line_length; kept for
        backward compatibility with older callers.
        """
        base_x = 0  # Relative to parent genotype node
        query_line_length = 3.0  # Reference length for query (100%)
        if sequence['is_query']:
            return base_x + query_line_length  # 100% length for query
        elif sequence['is_matched']:
            # Line length varies based on similarity percentage
            similarity = sequence['similarity']
            proportional_length = (similarity / 100.0) * query_line_length
            min_length = query_line_length * 0.2  # Minimum 20% of query length
            return base_x + max(proportional_length, min_length)
        else:
            return base_x + (query_line_length * 0.5)  # 50% length for other sequences
else: return base_x + (query_line_length * 0.5) # 50% length for other sequences def create_interactive_tree(self, matched_ids: List[str], actual_percentage: float): try: print("🎨 Creating horizontal interactive tree visualization...") # Prepare data for plotting edge_x = [] edge_y = [] node_x = [] node_y = [] node_colors = [] node_text = [] node_hover = [] node_sizes = [] # Updated color scheme for new node types colors = { 'root': '#FF0000', # Red for root 'normalized_ml_group': '#FFB6C1', # Light pink for normalized ML groups 'full_ml_group': '#FF69B4', # Hot pink for full ML groups 'genotype': '#FFD700', # Gold for genotypes 'representative_sequence': '#FFA500', # Orange for representative sequences 'query_sequence': '#4B0082', # Dark purple for query 'matched_sequence': '#6A5ACD', # Slate blue for matched 'other_sequence': '#87CEEB' # Sky blue for others } def add_horizontal_edges(parent_x, parent_y, children_dict): """Add horizontal connecting lines with proper vertical line sizing""" if not children_dict: return children_list = list(children_dict.values()) if len(children_list) == 1: # Single child - direct horizontal line child = children_list[0] edge_x.extend([parent_x, child['x'], None]) edge_y.extend([parent_y, child['y'], None]) else: # Multiple children - horizontal line with vertical distribution # Calculate the intermediate x position (where vertical line will be) child_x_positions = [child['x'] for child in children_list] min_child_x = min(child_x_positions) intermediate_x = parent_x + (min_child_x - parent_x) * 0.8 # 80% of the way to nearest child # Horizontal line to intermediate point edge_x.extend([parent_x, intermediate_x, None]) edge_y.extend([parent_y, parent_y, None]) # Calculate vertical line range to fit exactly all children child_y_positions = [child['y'] for child in children_list] min_y, max_y = min(child_y_positions), max(child_y_positions) # Vertical line sized exactly to fit all children edge_x.extend([intermediate_x, 
intermediate_x, None]) edge_y.extend([min_y, max_y, None]) # Horizontal lines from vertical line to each child for child in children_list: edge_x.extend([intermediate_x, child['x'], None]) edge_y.extend([child['y'], child['y'], None]) def get_node_color_and_size(node): """Determine node color and size based on type and content""" if node['type'] == 'sequence': if node['data']['is_query']: return colors['query_sequence'], 10 # Reduced size for compactness elif node['data']['is_matched']: return colors['matched_sequence'], 8 else: return colors['other_sequence'], 6 elif node['type'] == 'representative_sequence': return colors['representative_sequence'], 7 elif node['type'] == 'normalized_ml_group': # Larger size if it has special sequences size = 9 if node.get('has_special_sequences', False) else 7 return colors['normalized_ml_group'], size elif node['type'] == 'full_ml_group': return colors['full_ml_group'], 8 elif node['type'] == 'genotype': return colors['genotype'], 7 else: return colors.get(node['type'], '#000000'), 7 def create_node_text(node): """Create appropriate text label for each node type""" if node['type'] == 'sequence': if node['data']['is_matched'] and not node['data']['is_query']: return f"{node['name']}" else: return node['name'] elif node['type'] == 'representative_sequence': return node['name'] elif node['type'] == 'normalized_ml_group': # Add indicator if it has special sequences suffix = " *" if node.get('has_special_sequences', False) else "" return f"{node['name']}{suffix}" else: return node['name'] def create_hover_text(node): """Create detailed hover text for each node type""" if node['type'] == 'sequence': data = node['data']['data'] hover_text = ( f"{node['name']}
" f"Type: {'Query Sequence' if node['data']['is_query'] else 'Matched Sequence' if node['data']['is_matched'] else 'Other Sequence'}
" f"ML Group: {data.get('ML', 'N/A')}
" f"Genotype: {data.get('Genotype', 'N/A')}
" f"Host: {data.get('Host', 'N/A')}
" f"Country: {data.get('Country', 'N/A')}
" f"Isolate: {data.get('Isolate', 'N/A')}
" f"Year: {data.get('Year', 'N/A')}" ) if node['data']['is_matched']: hover_text += f"
Similarity: {node['data']['similarity']}%" elif node['type'] == 'representative_sequence': data = node['data']['data'] hover_text = ( f"{node['name']}
" f"Type: Representative Sequence
" f"ML Group: {data.get('ML', 'N/A')}
" f"Genotype: {data.get('Genotype', 'N/A')}
" f"Host: {data.get('Host', 'N/A')}
" f"Country: {data.get('Country', 'N/A')}" ) elif node['type'] == 'normalized_ml_group': hover_text = f"{node['name']}
Type: Normalized ML Group" if node.get('has_special_sequences', False): hover_text += "
Contains query/matched sequences" else: hover_text += "
Representative sequences only" elif node['type'] == 'full_ml_group': hover_text = f"{node['name']}
Type: Full ML Group" if 'sequences_count' in node: hover_text += f"
Total Sequences: {node['sequences_count']}" elif node['type'] == 'genotype': hover_text = f"{node['name']}
Type: Genotype" if 'sequences' in node: special_count = sum(1 for seq in node['sequences'] if seq['is_query'] or seq['is_matched']) hover_text += f"
Special Sequences: {special_count}/{len(node['sequences'])}" else: hover_text = f"{node['name']}
Type: {node['type'].replace('_', ' ').title()}" return hover_text def add_node_and_edges(node, parent_x=None, parent_y=None): """Recursively add nodes and edges to the plot with equal spacing structure.""" x, y = node['x'], node['y'] node_x.append(x) node_y.append(y) # Get node color and size color, size = get_node_color_and_size(node) node_colors.append(color) node_sizes.append(size) # Create node text and hover node_text.append(create_node_text(node)) node_hover.append(create_hover_text(node)) # Process children with equal spacing structure if 'children' in node and node['children']: add_horizontal_edges(x, y, node['children']) for child in node['children'].values(): add_node_and_edges(child, x, y) # Build the plot data starting from root root_node = self.tree_structure['root'] add_node_and_edges(root_node) # Add horizontal edges for root level if root_node['children']: add_horizontal_edges(root_node['x'], root_node['y'], root_node['children']) # Create the figure fig = go.Figure() # Add edges fig.add_trace(go.Scatter( x=edge_x, y=edge_y, mode='lines', line=dict(width=1, color='gray', dash='solid'), # Thinner lines for compactness hoverinfo='none', showlegend=False, name='Edges' )) # Add nodes fig.add_trace(go.Scatter( x=node_x, y=node_y, mode='markers+text', marker=dict( size=node_sizes, color=node_colors, line=dict(width=1, color='black'), # Thinner borders opacity=0.85 ), text=node_text, textposition="middle right", textfont=dict(size=9, color="black"), # Smaller font for compactness hoverinfo='text', hovertext=node_hover, showlegend=False, name='Nodes' )) # Calculate proper layout dimensions to ensure everything fits if node_x and node_y: # Get the actual data bounds min_x, max_x = min(node_x), max(node_x) min_y, max_y = min(node_y), max(node_y) # Calculate ranges x_range = max_x - min_x y_range = max_y - min_y # Add padding to ensure nothing is cut off (20% padding on each side) x_padding = x_range * 0.2 if x_range > 0 else 1 y_padding = y_range * 0.2 if 
y_range > 0 else 1 # Set axis ranges with padding x_axis_range = [min_x - x_padding, max_x + x_padding] y_axis_range = [min_y - y_padding, max_y + y_padding] # Compact but sufficient sizing width = min(1400, max(800, int(x_range * 80 + 400))) # Cap max width height = min(900, max(500, int(y_range * 40 + 300))) # Cap max height else: width, height = 800, 500 x_axis_range = None y_axis_range = None # Update layout for compact horizontal tree with proper bounds fig.update_layout( title=dict( text=f"Compact Horizontal Phylogenetic Tree (ML-Based)
" f"Query: {self.query_id} | Similarity: {actual_percentage}% | " f"Matched: {len(matched_ids)}", x=0.5, font=dict(size=12) # Smaller title for compactness ), xaxis=dict( showgrid=False, gridcolor='lightgray', gridwidth=0.3, # Very thin grid lines zeroline=False, showticklabels=False, range=x_axis_range, # Set explicit range to prevent cutoff fixedrange=False, # Allow zooming if needed automargin=True # Automatically adjust margins ), yaxis=dict( showgrid=False, gridcolor='lightgray', gridwidth=0.3, # Very thin grid lines zeroline=False, showticklabels=False, range=y_axis_range, # Set explicit range to prevent cutoff fixedrange=False, # Allow zooming if needed automargin=True # Automatically adjust margins ), plot_bgcolor="white", paper_bgcolor="white", hovermode="closest", width=width, height=height, margin=dict(l=20, r=100, t=40, b=10), # Adequate margins, extra right margin for text autosize=False, # Don't auto-resize showlegend=True, legend=dict( x=1.02, # Position legend outside plot area y=1, xanchor='left', yanchor='top', bgcolor='rgba(255,255,255,0.8)', bordercolor='gray', borderwidth=1, font=dict(size=10) # Smaller legend font ) ) # Add comprehensive legend with smaller markers legend_elements = [ dict(name="Root", marker=dict(color=colors['root'], size=8)), dict(name="Normalized ML Groups", marker=dict(color=colors['normalized_ml_group'], size=8)), dict(name="Full ML Groups", marker=dict(color=colors['full_ml_group'], size=8)), dict(name="Genotypes", marker=dict(color=colors['genotype'], size=8)), dict(name="Query Sequence", marker=dict(color=colors['query_sequence'], size=10)), dict(name="Similar Sequences", marker=dict(color=colors['matched_sequence'], size=9)), dict(name="Representative Sequences", marker=dict(color=colors['representative_sequence'], size=8)), dict(name="Other Sequences", marker=dict(color=colors['other_sequence'], size=7)) ] for i, element in enumerate(legend_elements): fig.add_trace(go.Scatter( x=[None], y=[None], mode='markers', 
marker=element['marker'], name=element['name'], showlegend=True )) # Configure modebar for better user experience config = { 'displayModeBar': True, 'displaylogo': False, 'modeBarButtonsToRemove': ['select2d', 'lasso2d'], 'toImageButtonOptions': { 'format': 'png', 'filename': 'phylogenetic_tree', 'height': height, 'width': width, 'scale': 2 } } # Save outputs try: fig.write_html("phylogenetic_tree_normalized_horizontal.html", config=config) print("✓ Compact horizontal interactive tree saved as 'phylogenetic_tree_normalized_horizontal.html'") except Exception as e: print(f"Warning: Could not save HTML file: {e}") # Display the figure with config try: fig.show(config=config) except Exception as e: print(f"Warning: Could not display figure: {e}") return fig except Exception as e: print(f"Error creating compact horizontal interactive tree: {e}") return None def create_sequence_alignment(self, sequence_ids: List[str]) -> Optional[MultipleSeqAlignment]: try: print("🧬 Creating multiple sequence alignment...") # Get sequences sequences = [] for seq_id in sequence_ids: try: row = self.data[self.data['Accession Number'] == seq_id] if not row.empty: f_gene = str(row.iloc[0]['F-gene']) # Clean sequence (remove non-nucleotide characters) clean_seq = re.sub(r'[^ATGCN-]', '', f_gene.upper()) if len(clean_seq) > 10: # Minimum sequence length seq_record = SeqRecord(Seq(clean_seq), id=seq_id, description="") sequences.append(seq_record) except Exception as e: print(f"Warning: Skipping sequence {seq_id}: {e}") continue if len(sequences) < 2: print("❌ Need at least 2 valid sequences for alignment") return None # Simple alignment (you might want to use MUSCLE or CLUSTAL for better results) aligned_sequences = self._simple_alignment(sequences) print(f"✓ Alignment created with {len(aligned_sequences)} sequences") return MultipleSeqAlignment(aligned_sequences) except Exception as e: print(f"Error creating alignment: {e}") return None def _simple_alignment(self, sequences: List[SeqRecord]) 
-> List[SeqRecord]: try: # Find maximum length max_length = max(len(seq.seq) for seq in sequences) # Cap maximum length to prevent memory issues if max_length > 10000: max_length = 10000 print(f"Warning: Sequences truncated to {max_length} bp") # Pad sequences to same length aligned_sequences = [] for seq in sequences: seq_str = str(seq.seq)[:max_length] # Truncate if too long if len(seq_str) < max_length: # Pad with gaps at the end padded_seq = seq_str + '-' * (max_length - len(seq_str)) else: padded_seq = seq_str aligned_sequences.append(SeqRecord(Seq(padded_seq), id=seq.id, description=seq.description)) return aligned_sequences except Exception as e: print(f"Error in simple alignment: {e}") return sequences # Return original sequences as fallback def calculate_ml_distances(self, alignment: MultipleSeqAlignment) -> np.ndarray: try: print("📊 Calculating ML distances...") # Convert alignment to numeric matrix seq_matrix = self._alignment_to_matrix(alignment) n_sequences = len(alignment) if n_sequences == 0: return np.array([]) # Initialize distance matrix distance_matrix = np.zeros((n_sequences, n_sequences)) # Calculate pairwise ML distances for i in range(n_sequences): for j in range(i + 1, n_sequences): try: ml_distance = self._calculate_ml_distance_pair(seq_matrix[i], seq_matrix[j]) distance_matrix[i][j] = ml_distance distance_matrix[j][i] = ml_distance except Exception as e: print(f"Warning: Error calculating distance between sequences {i} and {j}: {e}") # Use maximum distance as fallback distance_matrix[i][j] = 1.0 distance_matrix[j][i] = 1.0 print("✓ ML distances calculated") return distance_matrix except Exception as e: print(f"Error calculating ML distances: {e}") return np.array([]) def _alignment_to_matrix(self, alignment: MultipleSeqAlignment) -> np.ndarray: try: # Nucleotide to number mapping nucleotide_map = {'A': 0, 'T': 1, 'G': 2, 'C': 3, 'N': 4, '-': 5} matrix = [] for record in alignment: sequence = str(record.seq).upper() numeric_seq = 
[nucleotide_map.get(nuc, 4) for nuc in sequence] matrix.append(numeric_seq) return np.array(matrix) except Exception as e: print(f"Error converting alignment to matrix: {e}") return np.array([]) def _calculate_ml_distance_pair(self, seq1: np.ndarray, seq2: np.ndarray) -> float: try: if len(seq1) == 0 or len(seq2) == 0: return 1.0 # Count differences (excluding gaps and N's) valid_positions = (seq1 < 4) & (seq2 < 4) # Exclude N's and gaps if np.sum(valid_positions) == 0: return 1.0 # Maximum distance if no valid comparisons differences = np.sum(seq1[valid_positions] != seq2[valid_positions]) total_valid = np.sum(valid_positions) if total_valid == 0: return 1.0 # Calculate proportion of differences p = differences / total_valid # Jukes-Cantor correction if p >= 0.75: return 1.0 # Maximum distance # JC distance formula: -3/4 * ln(1 - 4p/3) try: jc_distance = -0.75 * np.log(1 - (4 * p / 3)) return min(max(jc_distance, 0.0), 1.0) # Clamp between 0 and 1 except (ValueError, RuntimeWarning): return 1.0 # Return maximum distance if log calculation fails except Exception as e: return 1.0 # Return maximum distance on error def construct_ml_tree(self, alignment: MultipleSeqAlignment) -> Optional[BaseTree.Tree]: try: print("🌳 Constructing Maximum Likelihood tree...") # Calculate ML distance matrix distance_matrix = self.calculate_ml_distances(alignment) if distance_matrix.size == 0: return None # Create sequence names list sequence_names = [record.id for record in alignment] # Build tree using neighbor-joining on ML distances tree = self._build_nj_tree_from_distances(distance_matrix, sequence_names) # Optimize branch lengths using ML (with recursion protection) if tree: tree = self._optimize_branch_lengths_ml_safe(tree, alignment) print("✓ ML tree constructed successfully") return tree except Exception as e: print(f"Error constructing ML tree: {e}") return None def _build_nj_tree_from_distances(self, distance_matrix: np.ndarray, sequence_names: List[str]) -> 
Optional[BaseTree.Tree]: try: from Bio.Phylo.TreeConstruction import DistanceMatrix, DistanceTreeConstructor # Validate inputs if distance_matrix.shape[0] != len(sequence_names): print("Error: Distance matrix size doesn't match sequence names") return None # Convert numpy array to Bio.Phylo distance matrix format matrix_data = [] for i in range(len(sequence_names)): row = [] for j in range(i + 1): if i == j: row.append(0.0) else: # Ensure distance is valid dist = float(distance_matrix[i][j]) if np.isnan(dist) or np.isinf(dist): dist = 1.0 row.append(max(0.0, dist)) # Ensure non-negative matrix_data.append(row) # Create DistanceMatrix object dm = DistanceMatrix(names=sequence_names, matrix=matrix_data) # Build tree using Neighbor-Joining constructor = DistanceTreeConstructor() tree = constructor.nj(dm) # Validate tree structure if tree and self._validate_tree_structure(tree): return tree else: print("Warning: Tree structure validation failed") return tree # Return anyway, might still be usable except Exception as e: print(f"Error building NJ tree: {e}") return None def _validate_tree_structure(self, tree: BaseTree.Tree, max_depth: int = 100) -> bool: try: visited = set() def check_node(node, depth=0): if depth > max_depth: return False # Check for circular references node_id = id(node) if node_id in visited: return False visited.add(node_id) # Check children for child in getattr(node, 'clades', []): if not check_node(child, depth + 1): return False return True return check_node(tree.root if hasattr(tree, 'root') else tree) except Exception: return False def _optimize_branch_lengths_ml_safe(self, tree: BaseTree.Tree, alignment: MultipleSeqAlignment) -> BaseTree.Tree: try: print("🔧 Optimizing branch lengths with ML...") # Set recursion limit temporarily old_limit = sys.getrecursionlimit() sys.setrecursionlimit(1000) try: # Convert alignment to matrix seq_matrix = self._alignment_to_matrix(alignment) if seq_matrix.size == 0: print("Warning: Empty sequence matrix, 
skipping optimization") return tree # Get all internal and external nodes with depth tracking all_clades = self._get_clades_safe(tree) # Simple branch length optimization for clade in all_clades: if hasattr(clade, 'branch_length') and clade.branch_length is not None: try: # Calculate optimal branch length based on likelihood optimal_length = self._calculate_optimal_branch_length_safe(clade, seq_matrix) clade.branch_length = max(optimal_length, 0.001) # Minimum branch length except Exception as e: print(f"Warning: Failed to optimize branch for clade: {e}") # Keep original branch length pass print("✓ Branch lengths optimized") finally: # Restore original recursion limit sys.setrecursionlimit(old_limit) return tree except Exception as e: print(f"Warning: Branch length optimization failed: {e}") return tree def _get_clades_safe(self, tree: BaseTree.Tree, max_depth: int = 50) -> List: clades = [] visited = set() def traverse_node(node, depth=0): if depth > max_depth or id(node) in visited: return visited.add(id(node)) clades.append(node) # Traverse children safely try: children = getattr(node, 'clades', []) for child in children: traverse_node(child, depth + 1) except Exception: pass # Skip problematic nodes try: root = tree.root if hasattr(tree, 'root') else tree traverse_node(root) except Exception as e: print(f"Warning: Tree traversal error: {e}") return clades def _calculate_optimal_branch_length_safe(self, clade, seq_matrix: np.ndarray) -> float: try: # Simplified ML branch length estimation if not hasattr(clade, 'branch_length') or clade.branch_length is None: return 0.1 # Default branch length current_length = float(clade.branch_length) # Validate current length if np.isnan(current_length) or np.isinf(current_length) or current_length <= 0: return 0.1 # Simple optimization based on sequence characteristics if hasattr(clade, 'name') and clade.name: # For terminal nodes return min(max(current_length * 0.9, 0.001), 1.0) else: # For internal nodes return 
min(max(current_length * 1.1, 0.001), 1.0) except Exception: return 0.1 # Safe default def calculate_ml_likelihood_safe(self, tree: BaseTree.Tree, alignment: MultipleSeqAlignment) -> float: try: print("📈 Calculating tree likelihood...") seq_matrix = self._alignment_to_matrix(alignment) if seq_matrix.size == 0: return -np.inf # Simplified likelihood calculation using Jukes-Cantor model total_log_likelihood = 0.0 # For each site in the alignment (sample subset to avoid memory issues) n_sites = min(seq_matrix.shape[1], 1000) # Limit sites for performance for site in range(0, n_sites, max(1, n_sites // 100)): # Sample sites try: site_pattern = seq_matrix[:, site] # Skip sites with gaps or N's valid_positions = site_pattern < 4 if np.sum(valid_positions) < 2: continue # Calculate likelihood for this site pattern site_likelihood = self._calculate_site_likelihood_safe(tree, site_pattern) if site_likelihood > 0: total_log_likelihood += np.log(site_likelihood) except Exception as e: print(f"Warning: Error processing site {site}: {e}") continue print(f"✓ Tree likelihood calculated: {total_log_likelihood:.2f}") return total_log_likelihood except Exception as e: print(f"Error calculating likelihood: {e}") return -np.inf def _calculate_site_likelihood_safe(self, tree: BaseTree.Tree, site_pattern: np.ndarray) -> float: try: # Count nucleotide frequencies at this site valid_nucs = site_pattern[site_pattern < 4] if len(valid_nucs) == 0: return 1.0 # Simple likelihood based on nucleotide diversity unique_nucs = len(np.unique(valid_nucs)) total_nucs = len(valid_nucs) # Higher diversity = lower likelihood of simple evolution diversity_factor = unique_nucs / 4.0 # Normalize by 4 nucleotides # Simple likelihood model likelihood = np.exp(-diversity_factor * total_nucs * 0.1) return max(likelihood, 1e-10) # Avoid zero likelihood except Exception: return 1e-10 # Safe fallback def perform_ml_analysis_safe(self, matched_ids: List[str]) -> Dict: try: print("\n🧬 PERFORMING MAXIMUM LIKELIHOOD 
def build_tree_structure_with_ml_safe(self, matched_ids: List[str]) -> Dict:
    """Build the hierarchical tree structure, annotated with ML analysis results.

    Runs perform_ml_analysis_safe first, then the standard build_tree_structure,
    and attaches an 'ml_analysis' summary. On success the ML tree/alignment are
    cached on self.ml_tree / self.ml_alignment. Falls back to the plain
    structure (or an error dict) if both paths fail.
    """
    try:
        print("🌳 Building ML-enhanced tree structure...")
        ml_results = self.perform_ml_analysis_safe(matched_ids)
        tree_structure = self.build_tree_structure(matched_ids)
        if ml_results and 'tree' in ml_results:
            tree_structure['ml_analysis'] = {
                'log_likelihood': ml_results['log_likelihood'],
                'sequence_count': ml_results['sequence_count'],
                'alignment_length': ml_results['alignment_length'],
                'ml_tree_available': True
            }
            # Cache the ML tree for later use.
            self.ml_tree = ml_results['tree']
            self.ml_alignment = ml_results.get('alignment')
            print("✓ Tree structure enhanced with ML analysis")
        else:
            tree_structure['ml_analysis'] = {
                'ml_tree_available': False,
                'error': 'ML analysis failed'
            }
            print("⚠️ ML analysis failed, using standard tree structure")
        return tree_structure
    except Exception as e:
        print(f"Error building ML-enhanced tree structure: {e}")
        # Fallback to the original (non-ML) method.
        try:
            return self.build_tree_structure(matched_ids)
        except Exception as e2:
            print(f"Fallback also failed: {e2}")
            return {'error': 'Both ML and standard tree construction failed'}
ml_results['sequence_count'], 'alignment_length': ml_results['alignment_length'], 'ml_tree_available': True } # Store ML tree for later use self.ml_tree = ml_results['tree'] self.ml_alignment = ml_results.get('alignment') print("✓ Tree structure enhanced with ML analysis") else: tree_structure['ml_analysis'] = { 'ml_tree_available': False, 'error': 'ML analysis failed' } print("⚠️ ML analysis failed, using standard tree structure") return tree_structure except Exception as e: print(f"Error building ML-enhanced tree structure: {e}") # Fallback to original method try: return self.build_tree_structure(matched_ids) except Exception as e2: print(f"Fallback also failed: {e2}") return {'error': 'Both ML and standard tree construction failed'} def _print_tree_topology(self, tree, max_depth=3, current_depth=0, prefix=""): if current_depth > max_depth: return try: # Get all clades at current level clades = list(tree.find_clades()) for i, clade in enumerate(clades[:5]): # Limit to first 5 for readability branch_info = "" if clade.branch_length is not None: branch_info = f" (len: {clade.branch_length:.4f})" if clade.is_terminal(): node_name = clade.name or "Terminal" print(f" {prefix}├── {node_name}{branch_info}") else: node_name = clade.name or f"Internal_{i}" print(f" {prefix}├── {node_name}{branch_info}") if current_depth < max_depth - 1 and not clade.is_terminal(): # Show children (simplified) children = list(clade.find_clades()) if len(children) > 1: for j, child in enumerate(children[1:3]): # Show max 2 children child_name = child.name or f"Node_{j}" child_branch = f" (len: {child.branch_length:.4f})" if child.branch_length else "" print(f" {prefix}│ ├── {child_name}{child_branch}") except Exception as e: print(f" Error displaying topology: {e}") def main(): print("\n" + "="*70) print("🧬 PHYLOGENETIC TREE ANALYZER - ADVANCED ML-BASED ANALYSIS") print("="*70) print("Version 2.0 | AI-Enhanced Similarity Matching") print("Interactive Visualization with Variable Line 
Lengths") print("="*70) # Initialize the analyzer analyzer = PhylogeneticTreeAnalyzer() try: # Step 1: Load data while True: data_file = "f cleaned.csv" if not data_file: print("❌ Please provide a file path.") continue if not Path(data_file).exists(): print(f"❌ File not found: {data_file}") continue if analyzer.load_data(data_file): break else: print("❌ Failed to load data. Please check file format.") continue # Step 2: Train AI model automatically print("\n⏳ Training AI model... This may take a few moments.", flush=True) start_time = time.time() if analyzer.train_ai_model(): elapsed = time.time() - start_time print(f"✅ AI model training completed in {elapsed:.1f} seconds", flush=True) else: print("⚠️ AI model training failed, continuing with basic analysis", flush=True) # Step 3: Get query sequence while True: print("\n🔍 QUERY SEQUENCE INPUT:") print(" You can provide:") print(" 1. Accession Number (e.g., 'MH087032') - from your dataset") print(" 2. ANY F-gene nucleotide sequence (A, T, G, C)") print(" 3. Novel sequences will be compared against your dataset") print(" Note: Minimum sequence length is 10 nucleotides") query_input = input("\nEnter query sequence or ID: ").strip() if not query_input: print("❌ Please provide a query sequence or ID.") continue if analyzer.find_query_sequence(query_input): break else: retry = input("❌ Invalid input. Try again? 
(y/n): ").strip().lower() if retry != 'y': print("👋 Analysis cancelled.") return # Step 4: Set similarity percentage while True: try: print(f"\n📊 SIMILARITY THRESHOLD:") print(f" - Higher values (90-99%): Find very similar sequences") print(f" - Lower values (70-89%): Find more distantly related sequences") similarity_input = input(f"Enter target similarity percentage (1-99) [85]: ").strip() if not similarity_input: target_percentage = 85.0 # Lowered default for novel sequences else: target_percentage = float(similarity_input) if not (1 <= target_percentage <= 99): print("❌ Please enter a percentage between 1 and 99.") continue analyzer.matching_percentage = target_percentage break except ValueError: print("❌ Please enter a valid number.") continue # Step 5: Find similar sequences print(f"\n⏳ Analyzing sequences for {target_percentage}% similarity...") start_time = time.time() matched_ids, actual_percentage = analyzer.find_similar_sequences(target_percentage) if not matched_ids: print(f"❌ No similar sequences found at {target_percentage}% similarity.") print("💡 Try lowering the similarity percentage (e.g., 70-80%) to find more distant matches.") return analyzer.matched_sequences = matched_ids analyzer.actual_percentage = actual_percentage elapsed = time.time() - start_time print(f"✅ Similarity analysis completed in {elapsed:.1f} seconds") # Step 6: Build tree structure print("\n⏳ Building phylogenetic tree structure...") start_time = time.time() tree_structure = analyzer.build_tree_structure_with_ml_safe(matched_ids) if not tree_structure: print("❌ Failed to build tree structure.") return elapsed = time.time() - start_time print(f"✅ Tree structure built in {elapsed:.1f} seconds") # Step 7: Create visualization and save HTML print("\n⏳ Creating interactive visualization...") start_time = time.time() fig = analyzer.create_interactive_tree(matched_ids, actual_percentage) if fig: elapsed = time.time() - start_time print(f"✅ Visualization created in {elapsed:.1f} 
seconds") # Save the interactive HTML file html_filename = "phylogenetic_tree_interactive.html" fig.write_html(html_filename) print(f"📄 Interactive HTML saved: {html_filename}") print(f"\n🎉 Analysis completed successfully!") print(f" Query ID: {analyzer.query_id}") print(f" Query sequence length: {len(analyzer.query_sequence)} nucleotides") print(f" Similar sequences found: {len(matched_ids)}") print(f" Actual similarity percentage: {actual_percentage:.1f}%") print(f" HTML file generated: {html_filename}") else: print("❌ Visualization creation failed.") return except KeyboardInterrupt: print(f"\n\n⚠️ Analysis interrupted by user.") sys.exit(1) except Exception as e: print(f"\n❌ An error occurred during analysis: {e}") print(f"Please check your input data and try again.") sys.exit(1) def command_line_interface(): parser = argparse.ArgumentParser( description="Advanced Phylogenetic Tree Analyzer with AI-enhanced similarity matching", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: # %(prog)s -d data.csv -q MH087032 -s 95 # %(prog)s -d data.csv -q MH087032 -s 90 --no-ai --batch query1,query2,query3 """ ) parser.add_argument('-d', '--data', required=True, help='Path to CSV data file') parser.add_argument('-q', '--query', required=True, help='Query sequence ID or nucleotide sequence') parser.add_argument('-s', '--similarity', type=float, default=95.0, help='Target similarity percentage (70-99, default: 95)') parser.add_argument('--no-ai', action='store_true', help='Skip AI model training') parser.add_argument('--batch', help='Comma-separated list of query IDs for batch processing') parser.add_argument('--output-dir', default='.', help='Output directory for results') parser.add_argument('--save-json', action='store_true', help='Save detailed results to JSON') args = parser.parse_args() # Validate arguments if not (70 <= args.similarity <= 99): print("❌ Similarity percentage must be between 70 and 99.") sys.exit(1) if not 
Path(args.data).exists(): print(f"❌ Data file not found: {args.data}") sys.exit(1) # Initialize analyzer analyzer = PhylogeneticTreeAnalyzer() # Load data if not analyzer.load_data(args.data): print("❌ Failed to load data.") sys.exit(1) # Train AI model (unless disabled) if not args.no_ai: print("\n⏳ Training AI model... This may take a few moments.", flush=True) start_time = time.time() if analyzer.train_ai_model(): elapsed = time.time() - start_time print(f"✅ AI model training completed in {elapsed:.1f} seconds", flush=True) else: print("⚠️ AI model training failed, continuing with basic analysis", flush=True) # Process queries queries = args.batch.split(',') if args.batch else [args.query] for query in queries: query = query.strip() print(f"\n🔍 Processing: {query}") if analyzer.find_query_sequence(query): matched_ids, actual_percentage = analyzer.find_similar_sequences(args.similarity) if matched_ids: analyzer.build_tree_structure_with_ml_safe(matched_ids) fig = analyzer.create_interactive_tree(matched_ids, actual_percentage) if fig: # Save the interactive HTML file html_filename = f"phylogenetic_tree_{query.replace('/', '_')}_interactive.html" fig.write_html(html_filename) print(f"📄 Interactive HTML saved: {html_filename}") print(f"✅ Analysis completed for {query}") else: print(f"❌ No similar sequences found for {query}") else: print(f"❌ Query not found: {query}") if __name__ == "__main__": try: main() except KeyboardInterrupt: print(f"\n\n👋 Goodbye!") sys.exit(0) except Exception as e: print(f"\n❌ Unexpected error: {e}") sys.exit(1) #KR815908