Spaces:
No application file
No application file
| # -*- coding: utf-8 -*- | |
| """ML simplified tree.ipynb | |
| Automatically generated by Colab. | |
| Original file is located at | |
| https://colab.research.google.com/drive/1LiDjip-h70ilIex9PedpWCZARWglija7 | |
| """ | |
| # Commented out IPython magic to ensure Python compatibility. | |
| import pandas as pd | |
| import numpy as np | |
| import plotly.graph_objects as go | |
| import plotly.offline as pyo | |
| from plotly.subplots import make_subplots | |
| from Bio import Phylo, SeqIO, AlignIO | |
| from Bio.Phylo.TreeConstruction import DistanceCalculator, DistanceTreeConstructor | |
| from Bio.Align import MultipleSeqAlignment | |
| from Bio.Seq import Seq | |
| from Bio.SeqRecord import SeqRecord | |
| from sklearn.feature_extraction.text import CountVectorizer | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| from sklearn.ensemble import RandomForestClassifier | |
| from sklearn.model_selection import train_test_split | |
| from sklearn.preprocessing import LabelEncoder | |
| import warnings | |
| import os | |
| import sys | |
| from typing import Dict, List, Tuple, Optional, Any | |
| import json | |
| import re | |
| from scipy.optimize import minimize | |
| from scipy.spatial.distance import pdist, squareform | |
| from Bio.Phylo import BaseTree | |
| import itertools | |
| from collections import defaultdict, deque | |
| import argparse | |
| import time | |
| from pathlib import Path | |
| warnings.filterwarnings('ignore') | |
| class PhylogeneticTreeAnalyzer: | |
def __init__(self):
    """Create an analyzer with no data loaded and default settings."""
    # Dataset and the query currently being analysed
    self.data = None
    self.query_id = None
    self.query_sequence = None

    # Similarity search settings and results
    self.matching_percentage = 95.0
    self.actual_percentage = None
    self.similarity_scores = {}
    self.matched_sequences = []

    # Classifier and the encoder used for its target labels
    self.ai_model = None
    self.label_encoder = LabelEncoder()

    # ML-specific attributes
    self.ml_tree = None
    self.ml_alignment = None
    self.ml_results = {}

    # Tree layout state
    self.tree_structure = {}
    self.horizontal_line_tracker = []  # Lengths already handed out to branching nodes
    self.query_ml_group = None  # Base ML group that contains the query
    self.base_horizontal_length = 1.2  # Base length for horizontal lines
def load_data(self, data_file: str):
    """Load the sequence table from a CSV file into ``self.data``.

    The table is only accepted when it has the columns the analysis reads:
    'Accession Number', 'ML', 'Genotype' and 'F-gene'. On any failure
    ``self.data`` is left exactly as it was before the call.

    Args:
        data_file: Path of the CSV file, passed straight to ``pd.read_csv``.

    Returns:
        True when the table was loaded, False when the file could not be
        read or a required column is missing (a message is printed).
    """
    required_columns = ['Accession Number', 'ML', 'Genotype', 'F-gene']

    try:
        loaded = pd.read_csv(data_file)
    except Exception as e:
        print(f"Error loading data: {e}")
        return False

    # Report absent columns by name instead of failing later with a bare
    # KeyError message.
    missing_columns = [col for col in required_columns if col not in loaded.columns]
    if missing_columns:
        print(f"Error: Missing required columns: {missing_columns}")
        return False

    self.data = loaded
    print(f"✓ Data loaded successfully: {len(self.data)} sequences")
    print(f"✓ ML Groups found: {self.data['ML'].nunique()}")
    print(f"✓ Genotypes found: {self.data['Genotype'].nunique()}")
    return True
def calculate_f_gene_similarity(self, seq1: str, seq2: str) -> float:
    """Return the 5-mer Jaccard similarity of two sequences as a percentage.

    Both inputs are upper-cased and stripped of every character other than
    A, T, G and C. The score is ``|shared 5-mers| / |all 5-mers| * 100``,
    rounded to two decimals.

    Args:
        seq1: First nucleotide sequence (None / NaN count as empty).
        seq2: Second nucleotide sequence (None / NaN count as empty).

    Returns:
        A value in [0.0, 100.0]. 0.0 when either cleaned sequence is empty
        or when an unexpected error occurs. When either sequence is too
        short to contain a 5-mer, 100.0 only if the cleaned sequences are
        identical, else 0.0.
    """
    k = 5  # 5-mer analysis

    def _clean(raw) -> str:
        # None and float NaN carry no sequence. Without this check NaN would
        # be stringified to "NAN" and cleaned down to the single base "A".
        if raw is None or (isinstance(raw, float) and raw != raw):
            return ''
        return re.sub(r'[^ATGC]', '', str(raw).upper())

    try:
        clean1 = _clean(seq1)
        clean2 = _clean(seq2)
        if not clean1 or not clean2:
            return 0.0

        # No k-mers can be formed: fall back to an exact comparison rather
        # than declaring any two short sequences 100% similar.
        if len(clean1) < k or len(clean2) < k:
            return 100.0 if clean1 == clean2 else 0.0

        kmers1 = {clean1[i:i + k] for i in range(len(clean1) - k + 1)}
        kmers2 = {clean2[i:i + k] for i in range(len(clean2) - k + 1)}

        # Jaccard similarity; the union is non-empty because both sets are.
        intersection = len(kmers1 & kmers2)
        union = len(kmers1 | kmers2)
        return round((intersection / union) * 100, 2)
    except Exception as e:
        print(f"Error calculating similarity: {e}")
        return 0.0
def train_ai_model(self):
    """Train a random-forest classifier predicting the ML group from F-gene k-mers.

    Each sequence becomes a 100-value vector: the counts of the first 50
    3-mers followed by the first 50 4-mers, both enumerated over 'ATGC' in
    ``itertools.product`` order. Sequences with fewer than 3 valid bases get
    an all-zero vector. Labels come from the 'ML' column (missing values
    become 'Unknown') via ``self.label_encoder``. The fitted model is stored
    in ``self.ai_model``.

    Returns:
        True when a model was trained; False when no data is loaded, fewer
        than 10 rows are available, fewer than 2 classes exist, or training
        raised an error.
    """
    from collections import Counter

    try:
        # Skip training if insufficient data (or none loaded at all)
        if self.data is None or len(self.data) < 10:
            print("⚠️ Insufficient data to train AI model (min 10 samples required)", flush=True)
            return False
        print("🤖 Training AI model for sequence analysis...", flush=True)

        # The k-mer vocabulary is the same for every sequence, so build it once.
        vocab_3 = [''.join(p) for p in itertools.product('ATGC', repeat=3)][:50]
        vocab_4 = [''.join(p) for p in itertools.product('ATGC', repeat=4)][:50]

        features = []
        for seq in self.data['F-gene'].fillna('').astype(str):
            seq_clean = re.sub(r'[^ATGC]', '', seq.upper())
            if len(seq_clean) < 3:
                features.append([0] * 100)  # Placeholder for short sequences
                continue
            # Counter tallies all overlapping k-mers in a single pass.
            counts_3 = Counter(seq_clean[i:i + 3] for i in range(len(seq_clean) - 2))
            counts_4 = Counter(seq_clean[i:i + 4] for i in range(len(seq_clean) - 3))
            features.append(
                [counts_3[kmer] for kmer in vocab_3]
                + [counts_4[kmer] for kmer in vocab_4]
            )

        # Prepare target labels (ML groups)
        targets = self.label_encoder.fit_transform(self.data['ML'].fillna('Unknown'))
        if len(np.unique(targets)) < 2:
            print("⚠️ Need at least 2 distinct classes for training", flush=True)
            return False

        # Hold out 20% of the rows to report an accuracy figure.
        X = np.array(features)
        y = targets
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
        self.ai_model = RandomForestClassifier(n_estimators=100, random_state=42)
        self.ai_model.fit(X_train, y_train)

        accuracy = self.ai_model.score(X_test, y_test)
        print(f"✓ AI model trained successfully with accuracy: {accuracy:.2%}", flush=True)
        return True
    except Exception as e:
        print(f"🚨 CRITICAL training error: {e}", flush=True)
        import traceback
        traceback.print_exc()
        return False
def find_query_sequence(self, query_input: str) -> bool:
    """Resolve the query from an accession number or a raw nucleotide string.

    If ``query_input`` equals a value in the 'Accession Number' column, the
    query is that row's 'F-gene' sequence. Otherwise the input is treated as
    a sequence: it is upper-cased, stripped of everything except A/T/G/C,
    and accepted when at least 10 bases remain. Accepted novel sequences get
    an ID of the form ``QUERY_#####``.

    Args:
        query_input: Accession number or nucleotide sequence.

    Returns:
        True when ``self.query_id`` and ``self.query_sequence`` were set,
        False when the input is too short or an error occurred.
    """
    import zlib

    min_length = 10
    try:
        # Check if input is an accession number from the dataset
        if query_input in self.data['Accession Number'].values:
            query_row = self.data[self.data['Accession Number'] == query_input].iloc[0]
            self.query_id = query_input
            self.query_sequence = query_row['F-gene']
            print(f"✓ Query sequence found by ID: {query_input}")
            return True

        # Otherwise treat the input as a nucleotide sequence
        query_clean = re.sub(r'[^ATGC]', '', str(query_input).upper())
        if len(query_clean) < min_length:
            print(f"✗ Query sequence too short. Minimum length: 10 nucleotides (provided: {len(query_clean)})")
            return False

        # crc32 gives the same ID for the same sequence on every run; the
        # built-in hash() of a str is salted per process.
        checksum = zlib.crc32(query_clean.encode('ascii'))
        self.query_id = f"QUERY_{checksum % 100000:05d}"
        self.query_sequence = query_clean
        print(f"✓ Novel query sequence accepted with ID: {self.query_id}")
        print(f" Sequence length: {len(query_clean)} nucleotides")
        return True
    except Exception as e:
        print(f"Error processing query sequence: {e}")
        return False
def find_similar_sequences(self, target_percentage: float) -> Tuple[List[str], float]:
    """Select dataset sequences whose similarity to the query is near a target.

    Every row except the one whose accession equals ``self.query_id`` is
    scored with ``calculate_f_gene_similarity``. Rows within ±2 points of
    ``target_percentage`` are kept; if there are none, the closest score
    becomes the effective percentage and rows within ±1 point of it are
    kept. At most 50 rows (highest similarity first) are returned, and
    ``self.similarity_scores`` is replaced with their scores.

    Args:
        target_percentage: Desired similarity, in percent.

    Returns:
        ``(matched_ids, actual_percentage)``; ``([], target_percentage)``
        when nothing could be compared or an error occurred.
    """
    try:
        print(f"🔍 Finding sequences with {target_percentage}% similarity to query...")

        scored = []
        for _, record in self.data.iterrows():
            accession = record['Accession Number']
            # The query's own row (when it comes from the dataset) is not a match.
            if hasattr(self, 'query_id') and accession == self.query_id:
                continue
            try:
                score = self.calculate_f_gene_similarity(self.query_sequence, record['F-gene'])
                scored.append({
                    'id': accession,
                    'similarity': score,
                    'ml': record['ML'] if 'ML' in record else 'Unknown',
                    'genotype': record['Genotype'] if 'Genotype' in record else 'Unknown'
                })
            except Exception as seq_error:
                print(f"⚠ Skipping sequence {record['Accession Number']}: {seq_error}")
                continue

        if not scored:
            print("❌ No valid sequences found for comparison")
            return [], target_percentage

        # Highest similarity first
        scored.sort(key=lambda entry: entry['similarity'], reverse=True)

        tolerance = 2.0
        picked = [entry for entry in scored
                  if abs(entry['similarity'] - target_percentage) <= tolerance]
        if picked:
            actual_percentage = target_percentage
        else:
            # Nothing near the target: re-centre on the closest score (±1%)
            nearest = min(scored, key=lambda entry: abs(entry['similarity'] - target_percentage))
            actual_percentage = nearest['similarity']
            picked = [entry for entry in scored
                      if abs(entry['similarity'] - actual_percentage) <= 1.0]
            print(f"⚠ No sequences found at exactly {target_percentage}%. Using closest: {actual_percentage:.1f}%")

        # Cap the result so the visualization stays readable
        max_results = 50
        if len(picked) > max_results:
            picked = picked[:max_results]
            print(f"⚠ Limited results to top {max_results} matches for better visualization")

        self.similarity_scores = {entry['id']: entry['similarity'] for entry in picked}
        matched_ids = [entry['id'] for entry in picked]

        # Summary statistics over every scored row
        all_scores = [entry['similarity'] for entry in scored]
        max_sim = max(all_scores)
        min_sim = min(all_scores)
        avg_sim = sum(all_scores) / len(all_scores)
        print(f"✓ Found {len(matched_ids)} sequences at ~{actual_percentage:.1f}% similarity")
        print(f" Similarity range in dataset: {min_sim:.1f}% - {max_sim:.1f}% (avg: {avg_sim:.1f}%)")
        return matched_ids, actual_percentage
    except Exception as e:
        print(f"Error finding similar sequences: {e}")
        return [], target_percentage
def build_tree_structure(self, matched_ids: List[str]) -> Dict:
    """Build the nested tree dictionary for the loaded dataset.

    Rows of ``self.data`` are grouped by 'ML' and then 'Genotype'. Each row
    is tagged as query / matched and given its stored similarity score. The
    groups are passed through ``_normalize_ml_groups`` and laid out by
    ``_build_normalized_ml_nodes``. The result is also kept in
    ``self.tree_structure``.

    Args:
        matched_ids: Accession numbers selected by the similarity search.

    Returns:
        The tree dictionary (rooted at key 'root'), or ``{}`` on error.
    """
    try:
        print("🌳 Building normalized horizontal tree structure...")

        tree = {
            'root': {
                'name': 'Root',
                'type': 'root',
                'children': {},
                'x': 0,
                'y': 0,
                'has_vertical_attachment': False,
                'extension_level': 0
            }
        }

        # ML group -> genotype -> list of sequence entries
        grouped = {}
        for _, record in self.data.iterrows():
            ml_key = record['ML']
            genotype_key = record['Genotype']
            accession = record['Accession Number']
            bucket = grouped.setdefault(ml_key, {}).setdefault(genotype_key, [])
            bucket.append({
                'id': accession,
                'data': record.to_dict(),
                'is_query': accession == self.query_id,
                'is_matched': accession in matched_ids,
                'similarity': self.similarity_scores.get(accession, 0.0)
            })

        collapsed = self._normalize_ml_groups(grouped)
        self._build_normalized_ml_nodes(tree, collapsed, matched_ids)

        self.tree_structure = tree
        print(f"✓ Normalized horizontal tree structure built")
        return tree
    except Exception as e:
        print(f"Error building tree structure: {e}")
        return {}
| def _normalize_ml_groups(self, ml_groups: Dict) -> Dict: | |
| """Normalize ML group names and organize hierarchically""" | |
| try: | |
| normalized_groups = {} | |
| for ml_name, genotypes in ml_groups.items(): | |
| # Extract base ML name | |
| if ml_name.startswith('UNCL'): | |
| base_ml = 'UNCL' | |
| elif '.' in ml_name and any(char.isdigit() for char in ml_name): | |
| # For names like XII.1.2, XII.1, etc., extract the base (XII) | |
| base_ml = ml_name.split('.')[0] | |
| else: | |
| base_ml = ml_name | |
| # Initialize normalized group structure | |
| if base_ml not in normalized_groups: | |
| normalized_groups[base_ml] = { | |
| 'full_ml_groups': {}, | |
| 'representative_sequences': [], | |
| 'has_special_sequences': False | |
| } | |
| # Check if this ML group has query or matched sequences | |
| has_special = any( | |
| any(seq['is_query'] or seq['is_matched'] for seq in sequences) | |
| for sequences in genotypes.values() | |
| ) | |
| if has_special: | |
| normalized_groups[base_ml]['has_special_sequences'] = True | |
| normalized_groups[base_ml]['full_ml_groups'][ml_name] = genotypes | |
| else: | |
| # Add as representative (limit to 2 representatives) | |
| if len(normalized_groups[base_ml]['representative_sequences']) < 2: | |
| # Get 1-2 representative sequences from this ML group | |
| for genotype, sequences in list(genotypes.items())[:2]: | |
| if len(normalized_groups[base_ml]['representative_sequences']) < 2: | |
| normalized_groups[base_ml]['representative_sequences'].extend(sequences[:1]) | |
| return normalized_groups | |
| except Exception as e: | |
| print(f"Error normalizing ML groups: {e}") | |
| return {} | |
def _build_normalized_ml_nodes(self, tree_structure: Dict, normalized_ml_groups: Dict, matched_ids: List[str]):
    """Attach one node per base ML group under the root of ``tree_structure``.

    Resets ``self.horizontal_line_tracker``, refreshes the query's ML group,
    and gives each base group a Y slot from ``_calculate_dynamic_ml_positions``
    and an X equal to its horizontal line length. Groups with query/matched
    sequences are expanded via ``_build_full_ml_nodes``; the others only get
    their representative sequences. Errors are printed, not raised.

    Args:
        tree_structure: Tree dictionary containing the 'root' node; mutated.
        normalized_ml_groups: Output of ``_normalize_ml_groups``.
        matched_ids: Accession numbers selected by the similarity search.
    """
    try:
        self.horizontal_line_tracker = []
        self._identify_query_ml_group(normalized_ml_groups)
        y_slots = self._calculate_dynamic_ml_positions(normalized_ml_groups)

        # The root needs a vertical connector only when it fans out.
        root = tree_structure['root']
        root['has_vertical_attachment'] = len(normalized_ml_groups) > 1

        for slot, (base_ml, ml_data) in enumerate(normalized_ml_groups.items()):
            node_y = y_slots[slot]
            branches = ml_data['has_special_sequences'] and len(ml_data['full_ml_groups']) > 1
            holds_query = (base_ml == self.query_ml_group)
            line_length = self._determine_horizontal_line_length(
                'normalized_ml_group', branches, holds_query
            )
            node_x = line_length

            node = {
                'name': base_ml,
                'type': 'normalized_ml_group',
                'children': {},
                'x': node_x,
                'y': node_y,
                'has_special_sequences': ml_data['has_special_sequences'],
                'has_vertical_attachment': branches,
                'horizontal_line_length': line_length,
                'contains_query': holds_query
            }
            root['children'][base_ml] = node

            if ml_data['has_special_sequences']:
                self._build_full_ml_nodes(
                    node, ml_data['full_ml_groups'], node_y, matched_ids, node_x
                )
            else:
                self._add_representative_sequences(
                    node, ml_data['representative_sequences'], node_y, node_x
                )
    except Exception as e:
        print(f"Error building normalized ML nodes: {e}")
| def _calculate_dynamic_ml_positions(self, normalized_ml_groups: Dict) -> List[float]: | |
| """Calculate equal Y positions for all ML groups regardless of content""" | |
| try: | |
| ml_count = len(normalized_ml_groups) | |
| if ml_count == 0: | |
| return [] | |
| if ml_count == 1: | |
| return [0.0] | |
| # Equal spacing between all ML nodes | |
| total_spacing = (ml_count - 1) * 2.0 # 2.0 units between each ML node | |
| start_y = -total_spacing / 2 | |
| positions = [] | |
| for i in range(ml_count): | |
| positions.append(start_y + i * 2.0) | |
| return positions | |
| except Exception as e: | |
| print(f"Error calculating dynamic positions: {e}") | |
| return list(range(len(normalized_ml_groups))) | |
def _build_full_ml_nodes(self, normalized_ml_node: Dict, full_ml_groups: Dict, base_y: float, matched_ids: List[str], parent_x: float):
    """Add a child node per full ML name under a base ML node, then its genotypes.

    Each child sits at ``parent_x`` plus its horizontal line length, on a Y
    slot from ``_calculate_full_ml_positions``. A child is marked as having
    a vertical attachment when more than one of its genotypes holds a query
    or matched sequence. Errors are printed, not raised.

    Args:
        normalized_ml_node: Parent node; its 'children' dict is filled in.
        full_ml_groups: ``{full_ml_name: {genotype: [sequence entries]}}``.
        base_y: Y coordinate of the parent node.
        matched_ids: Passed through to ``_build_genotype_nodes``.
        parent_x: X coordinate of the parent node.
    """
    try:
        y_slots = self._calculate_full_ml_positions(full_ml_groups, base_y)

        for slot, (full_ml_name, genotypes) in enumerate(full_ml_groups.items()):
            node_y = y_slots[slot]

            # Count the genotypes that will actually be drawn.
            flagged_genotypes = sum(
                any(seq['is_query'] or seq['is_matched'] for seq in sequences)
                for sequences in genotypes.values()
            )
            branches = flagged_genotypes > 1
            holds_query = any(
                any(seq['is_query'] for seq in sequences)
                for sequences in genotypes.values()
            )

            line_length = self._determine_horizontal_line_length(
                'full_ml_group', branches, holds_query
            )
            node_x = parent_x + line_length

            node = {
                'name': full_ml_name,
                'type': 'full_ml_group',
                'children': {},
                'x': node_x,
                'y': node_y,
                'sequences_count': sum(len(seqs) for seqs in genotypes.values()),
                'has_vertical_attachment': branches,
                'horizontal_line_length': line_length,
                'contains_query': holds_query
            }
            normalized_ml_node['children'][full_ml_name] = node

            self._build_genotype_nodes(node, genotypes, node_y, matched_ids, node_x)
    except Exception as e:
        print(f"Error building full ML nodes: {e}")
| def _calculate_full_ml_positions(self, full_ml_groups: Dict, base_y: float) -> List[float]: | |
| """Calculate equal positions for full ML groups""" | |
| try: | |
| ml_count = len(full_ml_groups) | |
| if ml_count <= 1: | |
| return [base_y] | |
| # Equal spacing for full ML groups | |
| spacing = 1.5 # Fixed spacing between full ML groups | |
| start_y = base_y - (spacing * (ml_count - 1)) / 2 | |
| positions = [] | |
| for i in range(ml_count): | |
| positions.append(start_y + i * spacing) | |
| return positions | |
| except Exception as e: | |
| print(f"Error calculating full ML positions: {e}") | |
| return [base_y] * len(full_ml_groups) | |
def _build_genotype_nodes(self, full_ml_node: Dict, genotypes: Dict, base_y: float, matched_ids: List[str], parent_x: float):
    """Add genotype nodes (and their sequences) under a full ML node.

    Only genotypes holding a query or matched sequence are drawn. Each one
    gets a Y slot from ``_calculate_genotype_positions`` and an X offset
    whose length depends on how many such sequences it holds. Errors are
    printed, not raised.

    Args:
        full_ml_node: Parent node; its 'children' dict is filled in.
        genotypes: ``{genotype: [sequence entries]}``.
        base_y: Y coordinate of the parent node.
        matched_ids: Not used by this method.
        parent_x: X coordinate of the parent node.
    """
    try:
        # Keep each drawable genotype together with its flagged-sequence count.
        drawable = []
        for genotype, sequences in genotypes.items():
            flagged = [seq for seq in sequences if seq['is_query'] or seq['is_matched']]
            if flagged:
                drawable.append((genotype, sequences, len(flagged)))
        if not drawable:
            return

        y_slots = self._calculate_genotype_positions(
            [(genotype, sequences) for genotype, sequences, _ in drawable], base_y
        )

        for slot, (genotype, sequences, sequence_count) in enumerate(drawable):
            node_y = y_slots[slot]
            branches = sequence_count > 1
            holds_query = any(seq['is_query'] for seq in sequences)

            line_length = self._determine_genotype_horizontal_line_length(
                sequence_count, branches, holds_query
            )
            node_x = parent_x + line_length

            node = {
                'name': genotype,
                'type': 'genotype',
                'children': {},
                'x': node_x,
                'y': node_y,
                'sequences': sequences,
                'has_vertical_attachment': branches,
                'horizontal_line_length': line_length,
                'contains_query': holds_query,
                'sequence_count': sequence_count  # Store for reference
            }
            full_ml_node['children'][genotype] = node

            self._add_sequences_horizontal(node, sequences, node_y, node_x)
    except Exception as e:
        print(f"Error building genotype nodes: {e}")
| def _determine_genotype_horizontal_line_length(self, sequence_count: int, has_vertical: bool, contains_query: bool = False) -> float: | |
| """Determine horizontal line length for genotype nodes based on sequence count""" | |
| try: | |
| base_length = self.base_horizontal_length | |
| # Special case: Genotype containing query sequence gets additional length | |
| query_bonus = 0.5 if contains_query else 0.0 | |
| # Calculate length based on sequence count | |
| # More sequences = longer horizontal line | |
| if sequence_count <= 1: | |
| # Single sequence | |
| length_multiplier = 1.0 | |
| elif sequence_count <= 3: | |
| # 2-3 sequences | |
| length_multiplier = 1.6 | |
| elif sequence_count <= 5: | |
| # 4-5 sequences | |
| length_multiplier = 2.3 | |
| elif sequence_count <= 8: | |
| # 6-8 sequences | |
| length_multiplier = 6.0 | |
| else: | |
| # More than 8 sequences | |
| length_multiplier = 6.0 | |
| # Calculate final length | |
| calculated_length = base_length * length_multiplier + query_bonus | |
| return calculated_length | |
| except Exception as e: | |
| print(f"Error determining genotype horizontal line length: {e}") | |
| return self.base_horizontal_length | |
| def _calculate_genotype_positions(self, special_genotypes: List, base_y: float) -> List[float]: | |
| """Calculate equal positions for genotypes""" | |
| try: | |
| genotype_count = len(special_genotypes) | |
| if genotype_count <= 1: | |
| return [base_y] | |
| # Equal spacing for genotypes | |
| spacing = 1.0 # Fixed spacing between genotypes | |
| start_y = base_y - (spacing * (genotype_count - 1)) / 2 | |
| positions = [] | |
| for i in range(genotype_count): | |
| positions.append(start_y + i * spacing) | |
| return positions | |
| except Exception as e: | |
| print(f"Error calculating genotype positions: {e}") | |
| return [base_y] * len(special_genotypes) | |
def _add_representative_sequences(self, normalized_ml_node: Dict, representative_sequences: List[Dict], base_y: float, parent_x: float):
    """Add representative sequences directly to a normalized ML node.

    All representatives share one X (``parent_x`` plus the 'representative'
    line length). A single representative sits at ``base_y``; several are
    spread with ``_calculate_sequence_positions``. Children are keyed
    ``"<id>_rep"``. Does nothing for an empty list; errors are printed.
    """
    try:
        if not representative_sequences:
            return

        count = len(representative_sequences)
        line_length = self._determine_horizontal_line_length('representative', count > 1)
        child_x = parent_x + line_length

        if count == 1:
            y_slots = [base_y]
        else:
            y_slots = self._calculate_sequence_positions(representative_sequences, base_y)

        for slot, seq in enumerate(representative_sequences):
            normalized_ml_node['children'][f"{seq['id']}_rep"] = {
                'name': f"{seq['id']} (Rep)",
                'type': 'representative_sequence',
                'data': seq,
                'x': child_x,
                'y': y_slots[slot],
                'has_vertical_attachment': False,
                'horizontal_line_length': line_length
            }
    except Exception as e:
        print(f"Error adding representative sequences: {e}")
def _add_sequences_horizontal(self, genotype_node: Dict, sequences: List[Dict], base_y: float, parent_x: float):
    """Add query and matched sequences as leaf nodes under a genotype node.

    Query sequences come first, then matched ones. Each leaf's X is
    ``parent_x`` plus a length from ``_calculate_similarity_based_line_length``
    (reference length 3.0). A single leaf sits at ``base_y``; otherwise the
    leaves are spread with ``_calculate_sequence_positions``. Matched
    sequences show their similarity in the label. Errors are printed.
    """
    try:
        reference_length = 3.0  # Line length of the query itself (100%)

        queries = [seq for seq in sequences if seq['is_query']]
        matches = [seq for seq in sequences if seq['is_matched'] and not seq['is_query']]
        leaves = queries + matches

        if len(leaves) == 1:
            y_slots = [base_y]
        else:
            y_slots = self._calculate_sequence_positions(leaves, base_y)

        for slot, sequence in enumerate(leaves):
            line_length = self._calculate_similarity_based_line_length(sequence, reference_length)
            suffix = ' (' + str(sequence['similarity']) + '%)' if sequence['is_matched'] else ''
            genotype_node['children'][sequence['id']] = {
                'name': f"{sequence['id']}{suffix}",
                'type': 'sequence',
                'data': sequence,
                'x': parent_x + line_length,
                'y': y_slots[slot],
                'has_vertical_attachment': False,
                'similarity_line_length': line_length
            }
    except Exception as e:
        print(f"Error adding sequences horizontally: {e}")
| def _calculate_similarity_based_line_length(self, sequence: Dict, query_line_length: float) -> float: | |
| """Calculate line length based on similarity percentage relative to query""" | |
| try: | |
| if sequence['is_query']: | |
| # Query sequence gets 100% length | |
| return query_line_length | |
| elif sequence['is_matched']: | |
| # Matched sequences get length proportional to their similarity | |
| similarity = sequence['similarity'] | |
| # Convert similarity percentage to proportional length | |
| proportional_length = (similarity / 100.0) * query_line_length | |
| # Ensure minimum length for visibility | |
| min_length = query_line_length * 0.2 # Minimum 20% of query length | |
| return max(proportional_length, min_length) | |
| else: | |
| # Other sequences get a standard length (50% of query) | |
| return query_line_length * 0.5 | |
| except Exception as e: | |
| print(f"Error calculating similarity-based line length: {e}") | |
| return query_line_length * 0.5 | |
| def _calculate_sequence_positions(self, sequences: List[Dict], base_y: float) -> List[float]: | |
| """Calculate equal positions for sequences""" | |
| try: | |
| seq_count = len(sequences) | |
| if seq_count <= 1: | |
| return [base_y] | |
| # Equal spacing for sequences | |
| spacing = 0.8 # Fixed spacing between sequences | |
| start_y = base_y - (spacing * (seq_count - 1)) / 2 | |
| positions = [] | |
| for i in range(seq_count): | |
| positions.append(start_y + i * spacing) | |
| return positions | |
| except Exception as e: | |
| print(f"Error calculating sequence positions: {e}") | |
| return [base_y] * len(sequences) | |
| def _determine_horizontal_line_length(self, node_type: str, has_vertical: bool, contains_query: bool = False) -> float: | |
| """Determine horizontal line length based on node type and connections""" | |
| try: | |
| base_length = self.base_horizontal_length | |
| # Special case: ML group containing query sequence gets much longer line | |
| if contains_query and node_type == 'normalized_ml_group': | |
| return base_length * 2.5 # Much longer for query ML group | |
| # If this node has a vertical line attachment (connects to multiple children) | |
| if has_vertical: | |
| # Find the current longest horizontal line with vertical | |
| current_max = base_length | |
| for tracked_length in self.horizontal_line_tracker: | |
| if tracked_length > current_max: | |
| current_max = tracked_length | |
| # Make this line incrementally longer | |
| new_length = current_max + 0.3 | |
| self.horizontal_line_tracker.append(new_length) | |
| return new_length | |
| else: | |
| # Direct connection (no vertical), use base length | |
| return base_length | |
| except Exception as e: | |
| print(f"Error determining horizontal line length: {e}") | |
| return self.base_horizontal_length | |
| def _identify_query_ml_group(self, normalized_ml_groups: Dict): | |
| """Identify which ML group contains the query sequence""" | |
| try: | |
| for base_ml, ml_data in normalized_ml_groups.items(): | |
| if ml_data['has_special_sequences']: | |
| for full_ml_name, genotypes in ml_data['full_ml_groups'].items(): | |
| for genotype, sequences in genotypes.items(): | |
| if any(seq['is_query'] for seq in sequences): | |
| self.query_ml_group = base_ml | |
| return | |
| except Exception as e: | |
| print(f"Error identifying query ML group: {e}") | |
| def _identify_query_ml_group(self, normalized_ml_groups: Dict): | |
| """Identify which ML group contains the query sequence""" | |
| try: | |
| for base_ml, ml_data in normalized_ml_groups.items(): | |
| if ml_data['has_special_sequences']: | |
| for full_ml_name, genotypes in ml_data['full_ml_groups'].items(): | |
| for genotype, sequences in genotypes.items(): | |
| if any(seq['is_query'] for seq in sequences): | |
| self.query_ml_group = base_ml | |
| return | |
| except Exception as e: | |
| print(f"Error identifying query ML group: {e}") | |
| def _calculate_sequence_x_position_horizontal(self, sequence: Dict, max_similarity: float) -> float: | |
| """Calculate X position based on similarity percentage for horizontal layout""" | |
| # This function is now replaced by _calculate_similarity_based_line_length | |
| # Keeping for backward compatibility, but the new approach is used in _add_sequences_horizontal | |
| base_x = 0 # Relative to parent genotype node | |
| query_line_length = 3.0 # Reference length for query (100%) | |
| if sequence['is_query']: | |
| return base_x + query_line_length # 100% length for query | |
| elif sequence['is_matched']: | |
| # Line length varies based on similarity percentage | |
| similarity = sequence['similarity'] | |
| proportional_length = (similarity / 100.0) * query_line_length | |
| min_length = query_line_length * 0.2 # Minimum 20% of query length | |
| return base_x + max(proportional_length, min_length) | |
| else: | |
| return base_x + (query_line_length * 0.5) # 50% length for other sequences | |
def create_interactive_tree(self, matched_ids: List[str], actual_percentage: float):
    """Render ``self.tree_structure`` as a horizontal Plotly tree.

    Walks the already-positioned node dicts below ``self.tree_structure['root']``
    (each node supplies ``x``, ``y``, ``type``, ``name`` and optionally
    ``children``), draws right-angled connectors, writes
    ``phylogenetic_tree_normalized_horizontal.html`` and tries to display the
    figure. Failures to save or display are reported but not fatal.

    Args:
        matched_ids: Matched sequence IDs; only their count appears in the title.
        actual_percentage: Similarity percentage shown in the title.

    Returns:
        The Plotly figure, or ``None`` if building it failed.
    """
    try:
        print("🎨 Creating horizontal interactive tree visualization...")

        # Flat coordinate/attribute lists consumed by the two Scatter traces.
        edge_x, edge_y = [], []
        node_x, node_y = [], []
        node_colors, node_sizes = [], []
        node_text, node_hover = [], []

        colors = {
            'root': '#FF0000',                     # Red for root
            'normalized_ml_group': '#FFB6C1',      # Light pink for normalized ML groups
            'full_ml_group': '#FF69B4',            # Hot pink for full ML groups
            'genotype': '#FFD700',                 # Gold for genotypes
            'representative_sequence': '#FFA500',  # Orange for representative sequences
            'query_sequence': '#4B0082',           # Dark purple for query
            'matched_sequence': '#6A5ACD',         # Slate blue for matched
            'other_sequence': '#87CEEB',           # Sky blue for others
        }
        # Marker sizes for node types whose size does not depend on content.
        fixed_sizes = {'representative_sequence': 7, 'full_ml_group': 8, 'genotype': 7}

        def add_horizontal_edges(parent_x, parent_y, children_dict):
            """Append connector segments from a parent to all of its children."""
            if not children_dict:
                return
            children = list(children_dict.values())

            if len(children) == 1:
                # Single child: one direct segment.
                only_child = children[0]
                edge_x.extend([parent_x, only_child['x'], None])
                edge_y.extend([parent_y, only_child['y'], None])
                return

            # Several children: stub to an elbow placed 80% of the way to the
            # nearest child, a vertical bar spanning the children, then one
            # horizontal segment per child.
            elbow_x = parent_x + (min(c['x'] for c in children) - parent_x) * 0.8
            child_ys = [c['y'] for c in children]

            edge_x.extend([parent_x, elbow_x, None])
            edge_y.extend([parent_y, parent_y, None])

            edge_x.extend([elbow_x, elbow_x, None])
            edge_y.extend([min(child_ys), max(child_ys), None])

            for c in children:
                edge_x.extend([elbow_x, c['x'], None])
                edge_y.extend([c['y'], c['y'], None])

        def get_node_color_and_size(node):
            """Return ``(color, marker_size)`` for a node."""
            node_type = node['type']
            if node_type == 'sequence':
                if node['data']['is_query']:
                    return colors['query_sequence'], 10
                if node['data']['is_matched']:
                    return colors['matched_sequence'], 8
                return colors['other_sequence'], 6
            if node_type == 'normalized_ml_group':
                # Slightly larger when the group holds query/matched sequences.
                size = 9 if node.get('has_special_sequences', False) else 7
                return colors['normalized_ml_group'], size
            return colors.get(node_type, '#000000'), fixed_sizes.get(node_type, 7)

        def create_node_text(node):
            """Return the label drawn next to a node."""
            if node['type'] == 'normalized_ml_group':
                # A trailing asterisk marks groups with query/matched sequences.
                suffix = " *" if node.get('has_special_sequences', False) else ""
                return f"{node['name']}{suffix}"
            return node['name']

        def create_hover_text(node):
            """Return the HTML hover text for a node."""
            node_type = node['type']
            if node_type == 'sequence':
                data = node['data']['data']
                if node['data']['is_query']:
                    kind = 'Query Sequence'
                elif node['data']['is_matched']:
                    kind = 'Matched Sequence'
                else:
                    kind = 'Other Sequence'
                hover_text = (
                    f"<b>{node['name']}</b><br>"
                    f"Type: {kind}<br>"
                    f"ML Group: {data.get('ML', 'N/A')}<br>"
                    f"Genotype: {data.get('Genotype', 'N/A')}<br>"
                    f"Host: {data.get('Host', 'N/A')}<br>"
                    f"Country: {data.get('Country', 'N/A')}<br>"
                    f"Isolate: {data.get('Isolate', 'N/A')}<br>"
                    f"Year: {data.get('Year', 'N/A')}"
                )
                if node['data']['is_matched']:
                    hover_text += f"<br><b>Similarity: {node['data']['similarity']}%</b>"
            elif node_type == 'representative_sequence':
                data = node['data']['data']
                hover_text = (
                    f"<b>{node['name']}</b><br>"
                    f"Type: Representative Sequence<br>"
                    f"ML Group: {data.get('ML', 'N/A')}<br>"
                    f"Genotype: {data.get('Genotype', 'N/A')}<br>"
                    f"Host: {data.get('Host', 'N/A')}<br>"
                    f"Country: {data.get('Country', 'N/A')}"
                )
            elif node_type == 'normalized_ml_group':
                hover_text = f"<b>{node['name']}</b><br>Type: Normalized ML Group"
                if node.get('has_special_sequences', False):
                    hover_text += "<br>Contains query/matched sequences"
                else:
                    hover_text += "<br>Representative sequences only"
            elif node_type == 'full_ml_group':
                hover_text = f"<b>{node['name']}</b><br>Type: Full ML Group"
                if 'sequences_count' in node:
                    hover_text += f"<br>Total Sequences: {node['sequences_count']}"
            elif node_type == 'genotype':
                hover_text = f"<b>{node['name']}</b><br>Type: Genotype"
                if 'sequences' in node:
                    special_count = sum(
                        1 for seq in node['sequences'] if seq['is_query'] or seq['is_matched']
                    )
                    hover_text += f"<br>Special Sequences: {special_count}/{len(node['sequences'])}"
            else:
                hover_text = f"<b>{node['name']}</b><br>Type: {node_type.replace('_', ' ').title()}"
            return hover_text

        def add_node_and_edges(node):
            """Recursively register a node, its connectors and its subtree."""
            x, y = node['x'], node['y']
            node_x.append(x)
            node_y.append(y)

            color, size = get_node_color_and_size(node)
            node_colors.append(color)
            node_sizes.append(size)

            node_text.append(create_node_text(node))
            node_hover.append(create_hover_text(node))

            children = node.get('children')
            if children:
                add_horizontal_edges(x, y, children)
                for child in children.values():
                    add_node_and_edges(child)

        # The recursive walk already emits the root's connectors, so they are
        # not added a second time afterwards (that used to duplicate them).
        add_node_and_edges(self.tree_structure['root'])

        fig = go.Figure()

        # Edges
        fig.add_trace(go.Scatter(
            x=edge_x, y=edge_y,
            mode='lines',
            line=dict(width=1, color='gray', dash='solid'),
            hoverinfo='none',
            showlegend=False,
            name='Edges'
        ))

        # Nodes
        fig.add_trace(go.Scatter(
            x=node_x, y=node_y,
            mode='markers+text',
            marker=dict(
                size=node_sizes,
                color=node_colors,
                line=dict(width=1, color='black'),
                opacity=0.85
            ),
            text=node_text,
            textposition="middle right",
            textfont=dict(size=9, color="black"),
            hoverinfo='text',
            hovertext=node_hover,
            showlegend=False,
            name='Nodes'
        ))

        # Axis ranges and figure size derived from the node bounding box.
        if node_x and node_y:
            min_x, max_x = min(node_x), max(node_x)
            min_y, max_y = min(node_y), max(node_y)
            x_range = max_x - min_x
            y_range = max_y - min_y

            # 20% padding per side; fall back to 1 unit for a degenerate range.
            x_padding = x_range * 0.2 if x_range > 0 else 1
            y_padding = y_range * 0.2 if y_range > 0 else 1
            x_axis_range = [min_x - x_padding, max_x + x_padding]
            y_axis_range = [min_y - y_padding, max_y + y_padding]

            # Clamp the figure to 800-1400 px wide and 500-900 px high.
            width = min(1400, max(800, int(x_range * 80 + 400)))
            height = min(900, max(500, int(y_range * 40 + 300)))
        else:
            width, height = 800, 500
            x_axis_range = None
            y_axis_range = None

        fig.update_layout(
            title=dict(
                text=f"Compact Horizontal Phylogenetic Tree (ML-Based)<br>"
                     f"Query: {self.query_id} | Similarity: {actual_percentage}% | "
                     f"Matched: {len(matched_ids)}",
                x=0.5,
                font=dict(size=12)
            ),
            xaxis=dict(
                showgrid=False,
                gridcolor='lightgray',
                gridwidth=0.3,
                zeroline=False,
                showticklabels=False,
                range=x_axis_range,
                fixedrange=False,
                automargin=True
            ),
            yaxis=dict(
                showgrid=False,
                gridcolor='lightgray',
                gridwidth=0.3,
                zeroline=False,
                showticklabels=False,
                range=y_axis_range,
                fixedrange=False,
                automargin=True
            ),
            plot_bgcolor="white",
            paper_bgcolor="white",
            hovermode="closest",
            width=width,
            height=height,
            margin=dict(l=20, r=100, t=40, b=10),  # extra right margin for labels
            autosize=False,
            showlegend=True,
            legend=dict(
                x=1.02,  # just outside the plot area
                y=1,
                xanchor='left',
                yanchor='top',
                bgcolor='rgba(255,255,255,0.8)',
                bordercolor='gray',
                borderwidth=1,
                font=dict(size=10)
            )
        )

        # Legend entries are data-less traces that only carry a marker style.
        legend_elements = [
            dict(name="Root", marker=dict(color=colors['root'], size=8)),
            dict(name="Normalized ML Groups", marker=dict(color=colors['normalized_ml_group'], size=8)),
            dict(name="Full ML Groups", marker=dict(color=colors['full_ml_group'], size=8)),
            dict(name="Genotypes", marker=dict(color=colors['genotype'], size=8)),
            dict(name="Query Sequence", marker=dict(color=colors['query_sequence'], size=10)),
            dict(name="Similar Sequences", marker=dict(color=colors['matched_sequence'], size=9)),
            dict(name="Representative Sequences", marker=dict(color=colors['representative_sequence'], size=8)),
            dict(name="Other Sequences", marker=dict(color=colors['other_sequence'], size=7)),
        ]
        for element in legend_elements:
            fig.add_trace(go.Scatter(
                x=[None], y=[None],
                mode='markers',
                marker=element['marker'],
                name=element['name'],
                showlegend=True
            ))

        # Mode bar configuration, including the PNG export settings.
        config = {
            'displayModeBar': True,
            'displaylogo': False,
            'modeBarButtonsToRemove': ['select2d', 'lasso2d'],
            'toImageButtonOptions': {
                'format': 'png',
                'filename': 'phylogenetic_tree',
                'height': height,
                'width': width,
                'scale': 2
            }
        }

        # Saving and displaying are best-effort: report and carry on.
        try:
            fig.write_html("phylogenetic_tree_normalized_horizontal.html", config=config)
            print("✓ Compact horizontal interactive tree saved as 'phylogenetic_tree_normalized_horizontal.html'")
        except Exception as e:
            print(f"Warning: Could not save HTML file: {e}")

        try:
            fig.show(config=config)
        except Exception as e:
            print(f"Warning: Could not display figure: {e}")

        return fig
    except Exception as e:
        print(f"Error creating compact horizontal interactive tree: {e}")
        return None
def create_sequence_alignment(self, sequence_ids: List[str]) -> Optional[MultipleSeqAlignment]:
    """Build an alignment from the F-gene sequences of the given accessions.

    Each ID is looked up in ``self.data`` by ``'Accession Number'``; the
    ``'F-gene'`` value is upper-cased, stripped of every character other than
    ``A``, ``T``, ``G``, ``C``, ``N`` and ``-``, and kept only if more than 10
    characters remain. The surviving records are equalized in length by
    ``_simple_alignment``.

    Args:
        sequence_ids: Accession numbers to include.

    Returns:
        The alignment, or ``None`` when fewer than two usable sequences were
        found or an error occurred.
    """
    try:
        print("🧬 Creating multiple sequence alignment...")

        records = []
        for seq_id in sequence_ids:
            try:
                hits = self.data[self.data['Accession Number'] == seq_id]
                if hits.empty:
                    continue
                raw_gene = str(hits.iloc[0]['F-gene'])
                # Keep nucleotide/gap symbols only.
                cleaned = re.sub(r'[^ATGCN-]', '', raw_gene.upper())
                if len(cleaned) <= 10:
                    continue  # too short to be useful
                records.append(SeqRecord(Seq(cleaned), id=seq_id, description=""))
            except Exception as e:
                print(f"Warning: Skipping sequence {seq_id}: {e}")
                continue

        if len(records) < 2:
            print("❌ Need at least 2 valid sequences for alignment")
            return None

        # Padding-based pseudo-alignment; a real aligner (MUSCLE, CLUSTAL)
        # would give better results.
        equalized = self._simple_alignment(records)
        print(f"✓ Alignment created with {len(equalized)} sequences")
        return MultipleSeqAlignment(equalized)
    except Exception as e:
        print(f"Error creating alignment: {e}")
        return None
def _simple_alignment(self, sequences: List[SeqRecord]) -> List[SeqRecord]:
    """Bring all records to one common length by truncating and gap-padding.

    The common length is the longest input sequence, capped at 10000. Longer
    sequences are cut, shorter ones are right-padded with ``-``. No residues
    are shifted, so this is not a true alignment.

    Args:
        sequences: Records to equalize.

    Returns:
        New records of equal length (same ``id``/``description``), or the
        input list unchanged if an error occurred.
    """
    try:
        target_len = max(len(rec.seq) for rec in sequences)

        # Upper bound on the alignment width to limit memory use.
        if target_len > 10000:
            target_len = 10000
            print(f"Warning: Sequences truncated to {target_len} bp")

        # Slice to the target width, then pad the right side with gaps.
        return [
            SeqRecord(
                Seq(str(rec.seq)[:target_len].ljust(target_len, '-')),
                id=rec.id,
                description=rec.description,
            )
            for rec in sequences
        ]
    except Exception as e:
        print(f"Error in simple alignment: {e}")
        return sequences  # fall back to the untouched input
def calculate_ml_distances(self, alignment: MultipleSeqAlignment) -> np.ndarray:
    """Compute the symmetric pairwise distance matrix for an alignment.

    Distances come from ``_calculate_ml_distance_pair`` applied to the numeric
    rows produced by ``_alignment_to_matrix``. A pair whose computation raises
    is assigned 1.0.

    Args:
        alignment: Alignment whose records are compared pairwise.

    Returns:
        An ``n x n`` array with a zero diagonal, or an empty array when the
        alignment has no records or an error occurred.
    """
    try:
        print("📊 Calculating ML distances...")

        encoded = self._alignment_to_matrix(alignment)
        n_sequences = len(alignment)
        if n_sequences == 0:
            return np.array([])

        distance_matrix = np.zeros((n_sequences, n_sequences))

        # Visit every unordered pair once and mirror the value.
        for i, j in itertools.combinations(range(n_sequences), 2):
            try:
                pair_distance = self._calculate_ml_distance_pair(encoded[i], encoded[j])
            except Exception as e:
                print(f"Warning: Error calculating distance between sequences {i} and {j}: {e}")
                pair_distance = 1.0  # maximum distance as fallback
            distance_matrix[i][j] = pair_distance
            distance_matrix[j][i] = pair_distance

        print("✓ ML distances calculated")
        return distance_matrix
    except Exception as e:
        print(f"Error calculating ML distances: {e}")
        return np.array([])
def _alignment_to_matrix(self, alignment: MultipleSeqAlignment) -> np.ndarray:
    """Encode every record of an alignment as a row of integer codes.

    Codes: ``A=0, T=1, G=2, C=3, N=4, -=5``; any other character is encoded
    as 4 (same as ``N``). Sequences are upper-cased first.

    Args:
        alignment: Iterable of records exposing ``.seq``.

    Returns:
        ``np.array`` built from one code list per record, or an empty array on
        error.
    """
    try:
        codes = {'A': 0, 'T': 1, 'G': 2, 'C': 3, 'N': 4, '-': 5}
        encoded_rows = [
            [codes.get(symbol, 4) for symbol in str(record.seq).upper()]
            for record in alignment
        ]
        return np.array(encoded_rows)
    except Exception as e:
        print(f"Error converting alignment to matrix: {e}")
        return np.array([])
| def _calculate_ml_distance_pair(self, seq1: np.ndarray, seq2: np.ndarray) -> float: | |
| try: | |
| if len(seq1) == 0 or len(seq2) == 0: | |
| return 1.0 | |
| # Count differences (excluding gaps and N's) | |
| valid_positions = (seq1 < 4) & (seq2 < 4) # Exclude N's and gaps | |
| if np.sum(valid_positions) == 0: | |
| return 1.0 # Maximum distance if no valid comparisons | |
| differences = np.sum(seq1[valid_positions] != seq2[valid_positions]) | |
| total_valid = np.sum(valid_positions) | |
| if total_valid == 0: | |
| return 1.0 | |
| # Calculate proportion of differences | |
| p = differences / total_valid | |
| # Jukes-Cantor correction | |
| if p >= 0.75: | |
| return 1.0 # Maximum distance | |
| # JC distance formula: -3/4 * ln(1 - 4p/3) | |
| try: | |
| jc_distance = -0.75 * np.log(1 - (4 * p / 3)) | |
| return min(max(jc_distance, 0.0), 1.0) # Clamp between 0 and 1 | |
| except (ValueError, RuntimeWarning): | |
| return 1.0 # Return maximum distance if log calculation fails | |
| except Exception as e: | |
| return 1.0 # Return maximum distance on error | |
def construct_ml_tree(self, alignment: MultipleSeqAlignment) -> Optional[BaseTree.Tree]:
    """Build a tree for an alignment from its pairwise distance matrix.

    Steps: ``calculate_ml_distances`` -> ``_build_nj_tree_from_distances``
    (neighbor joining) -> ``_optimize_branch_lengths_ml_safe``.

    Args:
        alignment: Alignment to build the tree from.

    Returns:
        The tree, or ``None`` when the distance matrix is empty, the tree
        could not be built, or an error occurred.
    """
    try:
        print("🌳 Constructing Maximum Likelihood tree...")

        distance_matrix = self.calculate_ml_distances(alignment)
        if distance_matrix.size == 0:
            return None

        sequence_names = [record.id for record in alignment]
        tree = self._build_nj_tree_from_distances(distance_matrix, sequence_names)
        if not tree:
            # The builder already reported the failure; do not claim success.
            return None

        tree = self._optimize_branch_lengths_ml_safe(tree, alignment)
        print("✓ ML tree constructed successfully")
        return tree
    except Exception as e:
        print(f"Error constructing ML tree: {e}")
        return None
def _build_nj_tree_from_distances(self, distance_matrix: np.ndarray, sequence_names: List[str]) -> Optional[BaseTree.Tree]:
    """Run neighbor joining on a square distance matrix.

    Args:
        distance_matrix: Square array of pairwise distances; only the lower
            triangle is read.
        sequence_names: One name per matrix row.

    Returns:
        The neighbor-joining tree (returned even if the structure check
        fails, after a warning), or ``None`` when the sizes disagree or an
        error occurred.
    """
    try:
        from Bio.Phylo.TreeConstruction import DistanceMatrix, DistanceTreeConstructor

        n_names = len(sequence_names)
        if distance_matrix.shape[0] != n_names:
            print("Error: Distance matrix size doesn't match sequence names")
            return None

        def _sanitized(value):
            # NaN/inf become the maximum distance; negatives are floored at 0.
            dist = float(value)
            if np.isnan(dist) or np.isinf(dist):
                dist = 1.0
            return max(0.0, dist)

        # Lower-triangular rows ending with the zero diagonal entry.
        matrix_data = [
            [_sanitized(distance_matrix[i][j]) for j in range(i)] + [0.0]
            for i in range(n_names)
        ]

        dm = DistanceMatrix(names=sequence_names, matrix=matrix_data)
        tree = DistanceTreeConstructor().nj(dm)

        if not (tree and self._validate_tree_structure(tree)):
            print("Warning: Tree structure validation failed")
        return tree  # returned either way; it might still be usable
    except Exception as e:
        print(f"Error building NJ tree: {e}")
        return None
def _validate_tree_structure(self, tree: BaseTree.Tree, max_depth: int = 100) -> bool:
    """Check that a tree is free of repeated nodes and not deeper than ``max_depth``.

    Args:
        tree: Tree (its ``root`` is used when present) or a node with ``clades``.
        max_depth: Deepest level allowed; the start node is level 0.

    Returns:
        ``True`` if every node is reached exactly once within the depth limit,
        ``False`` otherwise or if the traversal raised.
    """
    try:
        seen_ids = set()

        def _descend(node, depth):
            # A node seen twice means a cycle or a shared subtree.
            if depth > max_depth or id(node) in seen_ids:
                return False
            seen_ids.add(id(node))
            return all(
                _descend(child, depth + 1)
                for child in getattr(node, 'clades', [])
            )

        start = tree.root if hasattr(tree, 'root') else tree
        return _descend(start, 0)
    except Exception:
        return False
def _optimize_branch_lengths_ml_safe(self, tree: BaseTree.Tree, alignment: MultipleSeqAlignment) -> BaseTree.Tree:
    """Adjust the branch lengths of ``tree`` in place.

    Every clade returned by ``_get_clades_safe`` that has a branch length gets
    the value proposed by ``_calculate_optimal_branch_length_safe``, floored
    at 0.001. Clades whose update fails keep their original length.

    Args:
        tree: Tree to modify.
        alignment: Alignment converted to a numeric matrix for the estimator.

    Returns:
        The same tree object, modified where possible.
    """
    try:
        print("🔧 Optimizing branch lengths with ML...")

        # Guarantee at least 1000 frames, but never lower an already higher
        # limit: lowering can raise RecursionError when the current call
        # stack is deeper than the new limit.
        old_limit = sys.getrecursionlimit()
        sys.setrecursionlimit(max(old_limit, 1000))
        try:
            seq_matrix = self._alignment_to_matrix(alignment)
            if seq_matrix.size == 0:
                print("Warning: Empty sequence matrix, skipping optimization")
                return tree

            for clade in self._get_clades_safe(tree):
                if getattr(clade, 'branch_length', None) is None:
                    continue
                try:
                    optimal_length = self._calculate_optimal_branch_length_safe(clade, seq_matrix)
                    clade.branch_length = max(optimal_length, 0.001)  # minimum branch length
                except Exception as e:
                    # Best effort: the clade keeps its original branch length.
                    print(f"Warning: Failed to optimize branch for clade: {e}")

            print("✓ Branch lengths optimized")
        finally:
            sys.setrecursionlimit(old_limit)
        return tree
    except Exception as e:
        print(f"Warning: Branch length optimization failed: {e}")
        return tree
def _get_clades_safe(self, tree: BaseTree.Tree, max_depth: int = 50) -> List:
    """Collect the nodes of a tree in pre-order, guarding against cycles and depth.

    Args:
        tree: Tree (its ``root`` is used when present) or a node with ``clades``.
        max_depth: Nodes deeper than this level (start node = 0) are not visited.

    Returns:
        The visited nodes, each at most once. Errors while reading a node's
        children are ignored; a failure to start the traversal is reported
        and whatever was collected is returned.
    """
    collected = []
    seen_ids = set()

    def _walk(node, depth=0):
        node_key = id(node)
        if depth > max_depth or node_key in seen_ids:
            return
        seen_ids.add(node_key)
        collected.append(node)
        try:
            for child in getattr(node, 'clades', []):
                _walk(child, depth + 1)
        except Exception:
            pass  # skip the remaining children of a problematic node

    try:
        _walk(tree.root if hasattr(tree, 'root') else tree)
    except Exception as e:
        print(f"Warning: Tree traversal error: {e}")
    return collected
| def _calculate_optimal_branch_length_safe(self, clade, seq_matrix: np.ndarray) -> float: | |
| try: | |
| # Simplified ML branch length estimation | |
| if not hasattr(clade, 'branch_length') or clade.branch_length is None: | |
| return 0.1 # Default branch length | |
| current_length = float(clade.branch_length) | |
| # Validate current length | |
| if np.isnan(current_length) or np.isinf(current_length) or current_length <= 0: | |
| return 0.1 | |
| # Simple optimization based on sequence characteristics | |
| if hasattr(clade, 'name') and clade.name: | |
| # For terminal nodes | |
| return min(max(current_length * 0.9, 0.001), 1.0) | |
| else: | |
| # For internal nodes | |
| return min(max(current_length * 1.1, 0.001), 1.0) | |
| except Exception: | |
| return 0.1 # Safe default | |
def calculate_ml_likelihood_safe(self, tree: BaseTree.Tree, alignment: MultipleSeqAlignment) -> float:
    """Sum approximate per-site log-likelihoods over a sample of alignment columns.

    At most the first 1000 columns are considered and sampled with a stride of
    ``max(1, n_sites // 100)``. Columns with fewer than two codes below 4 are
    skipped. Each remaining column is scored by
    ``_calculate_site_likelihood_safe``.

    Args:
        tree: Passed through to the per-site scorer.
        alignment: Alignment to score.

    Returns:
        The summed log-likelihood, or ``-inf`` when the encoded alignment is
        empty or an error occurred.
    """
    try:
        print("📈 Calculating tree likelihood...")

        seq_matrix = self._alignment_to_matrix(alignment)
        if seq_matrix.size == 0:
            return -np.inf

        total_log_likelihood = 0.0

        # Cap and subsample the columns for performance.
        n_sites = min(seq_matrix.shape[1], 1000)
        stride = max(1, n_sites // 100)

        for site in range(0, n_sites, stride):
            try:
                column = seq_matrix[:, site]
                # Need at least two unambiguous nucleotides in the column.
                if np.sum(column < 4) < 2:
                    continue
                site_likelihood = self._calculate_site_likelihood_safe(tree, column)
                if site_likelihood > 0:
                    total_log_likelihood += np.log(site_likelihood)
            except Exception as e:
                print(f"Warning: Error processing site {site}: {e}")
                continue

        print(f"✓ Tree likelihood calculated: {total_log_likelihood:.2f}")
        return total_log_likelihood
    except Exception as e:
        print(f"Error calculating likelihood: {e}")
        return -np.inf
def _calculate_site_likelihood_safe(self, tree: BaseTree.Tree, site_pattern: np.ndarray) -> float:
    """Score one alignment column with a diversity-based heuristic.

    With ``u`` distinct codes among the ``n`` codes below 4, the score is
    ``exp(-(u / 4) * n * 0.1)``, floored at ``1e-10``. ``tree`` is not
    consulted.

    Args:
        tree: Accepted for interface compatibility; not used.
        site_pattern: Integer codes of one column (NumPy array).

    Returns:
        1.0 when the column has no code below 4, the heuristic score
        otherwise, or ``1e-10`` on error.
    """
    try:
        observed = site_pattern[site_pattern < 4]
        n_observed = len(observed)
        if n_observed == 0:
            return 1.0

        n_distinct = len(np.unique(observed))
        # Share of the four nucleotides present in this column.
        diversity = n_distinct / 4.0

        score = np.exp(-diversity * n_observed * 0.1)
        return max(score, 1e-10)  # never return a zero likelihood
    except Exception:
        return 1e-10
def perform_ml_analysis_safe(self, matched_ids: List[str]) -> Dict:
    """Run the alignment -> distances -> tree -> likelihood pipeline.

    The analysed set is ``self.query_id`` followed by the matched IDs (without
    a repeated query ID), truncated to 20 entries. At least 3 entries are
    required.

    Args:
        matched_ids: IDs of sequences matched to the query.

    Returns:
        A dict with ``tree``, ``alignment``, ``distance_matrix``,
        ``log_likelihood``, ``sequence_count`` and ``alignment_length``, or an
        empty dict when any step produced nothing or an error occurred.
    """
    try:
        print("\n🧬 PERFORMING MAXIMUM LIKELIHOOD ANALYSIS")
        print("="*50)

        # Query first, then every matched ID that is not the query itself.
        all_sequences = [self.query_id, *(sid for sid in matched_ids if sid != self.query_id)]

        # Cap the workload to limit memory use.
        requested = len(all_sequences)
        if requested > 20:
            print(f"Warning: Limiting analysis to 20 sequences (had {requested})")
            all_sequences = all_sequences[:20]

        if len(all_sequences) < 3:
            print("❌ Need at least 3 sequences for ML analysis")
            return {}

        # Step 1: multiple sequence alignment
        alignment = self.create_sequence_alignment(all_sequences)
        if not alignment:
            return {}

        # Step 2: pairwise distances
        distance_matrix = self.calculate_ml_distances(alignment)
        if distance_matrix.size == 0:
            return {}

        # Step 3: tree
        ml_tree = self.construct_ml_tree(alignment)
        if not ml_tree:
            return {}

        # Step 4: likelihood
        log_likelihood = self.calculate_ml_likelihood_safe(ml_tree, alignment)

        # Step 5: results
        analysed = len(all_sequences)
        aligned_width = len(alignment[0]) if alignment else 0
        ml_results = {
            'tree': ml_tree,
            'alignment': alignment,
            'distance_matrix': distance_matrix,
            'log_likelihood': log_likelihood,
            'sequence_count': analysed,
            'alignment_length': aligned_width,
        }

        print(f"✅ ML analysis completed successfully")
        print(f" Sequences analyzed: {analysed}")
        print(f" Alignment length: {ml_results['alignment_length']}")
        print(f" Log-likelihood: {log_likelihood:.2f}")
        return ml_results
    except Exception as e:
        print(f"❌ ML analysis failed: {e}")
        import traceback
        traceback.print_exc()
        return {}
def build_tree_structure_with_ml_safe(self, matched_ids: List[str]) -> Dict:
    """Build the hierarchical tree structure and attach an ML analysis summary.

    Runs ``perform_ml_analysis_safe`` and ``build_tree_structure``, then adds
    an ``'ml_analysis'`` entry to the structure. On ML success the tree and
    alignment are also kept in ``self.ml_tree`` / ``self.ml_alignment``.

    Args:
        matched_ids: IDs of sequences matched to the query.

    Returns:
        The tree structure with ``'ml_analysis'``. If this method fails, the
        plain ``build_tree_structure`` result; if that fails too, a dict with
        a single ``'error'`` entry.
    """
    try:
        print("🌳 Building ML-enhanced tree structure...")

        # ML analysis first, then the regular hierarchical structure.
        ml_results = self.perform_ml_analysis_safe(matched_ids)
        tree_structure = self.build_tree_structure(matched_ids)

        ml_succeeded = bool(ml_results) and 'tree' in ml_results
        if not ml_succeeded:
            tree_structure['ml_analysis'] = {
                'ml_tree_available': False,
                'error': 'ML analysis failed'
            }
            print("⚠️ ML analysis failed, using standard tree structure")
            return tree_structure

        tree_structure['ml_analysis'] = {
            'log_likelihood': ml_results['log_likelihood'],
            'sequence_count': ml_results['sequence_count'],
            'alignment_length': ml_results['alignment_length'],
            'ml_tree_available': True
        }
        # Keep the ML artefacts for later use.
        self.ml_tree = ml_results['tree']
        self.ml_alignment = ml_results.get('alignment')
        print("✓ Tree structure enhanced with ML analysis")
        return tree_structure
    except Exception as e:
        print(f"Error building ML-enhanced tree structure: {e}")
        # Fall back to the plain structure.
        try:
            return self.build_tree_structure(matched_ids)
        except Exception as e2:
            print(f"Fallback also failed: {e2}")
            return {'error': 'Both ML and standard tree construction failed'}
| def _print_tree_topology(self, tree, max_depth=3, current_depth=0, prefix=""): | |
| if current_depth > max_depth: | |
| return | |
| try: | |
| # Get all clades at current level | |
| clades = list(tree.find_clades()) | |
| for i, clade in enumerate(clades[:5]): # Limit to first 5 for readability | |
| branch_info = "" | |
| if clade.branch_length is not None: | |
| branch_info = f" (len: {clade.branch_length:.4f})" | |
| if clade.is_terminal(): | |
| node_name = clade.name or "Terminal" | |
| print(f" {prefix}├── {node_name}{branch_info}") | |
| else: | |
| node_name = clade.name or f"Internal_{i}" | |
| print(f" {prefix}├── {node_name}{branch_info}") | |
| if current_depth < max_depth - 1 and not clade.is_terminal(): | |
| # Show children (simplified) | |
| children = list(clade.find_clades()) | |
| if len(children) > 1: | |
| for j, child in enumerate(children[1:3]): # Show max 2 children | |
| child_name = child.name or f"Node_{j}" | |
| child_branch = f" (len: {child.branch_length:.4f})" if child.branch_length else "" | |
| print(f" {prefix}│ ├── {child_name}{child_branch}") | |
| except Exception as e: | |
| print(f" Error displaying topology: {e}") | |
def _ask_for_query(analyzer):
    """Prompt until ``analyzer.find_query_sequence`` accepts the input.

    Returns True once a query was accepted, False if the user declines to
    retry after a rejected input.
    """
    while True:
        print("\n🔍 QUERY SEQUENCE INPUT:")
        print(" You can provide:")
        print(" 1. Accession Number (e.g., 'MH087032') - from your dataset")
        print(" 2. ANY F-gene nucleotide sequence (A, T, G, C)")
        print(" 3. Novel sequences will be compared against your dataset")
        print(" Note: Minimum sequence length is 10 nucleotides")
        query_input = input("\nEnter query sequence or ID: ").strip()
        if not query_input:
            print("❌ Please provide a query sequence or ID.")
            continue
        if analyzer.find_query_sequence(query_input):
            return True
        retry = input("❌ Invalid input. Try again? (y/n): ").strip().lower()
        if retry != 'y':
            print("👋 Analysis cancelled.")
            return False


def _ask_for_similarity(default=85.0):
    """Prompt until a similarity percentage in [1, 99] is entered.

    An empty answer selects ``default``.
    """
    while True:
        print("\n📊 SIMILARITY THRESHOLD:")
        print(" - Higher values (90-99%): Find very similar sequences")
        print(" - Lower values (70-89%): Find more distantly related sequences")
        raw = input("Enter target similarity percentage (1-99) [85]: ").strip()
        if not raw:
            return default
        try:
            value = float(raw)
        except ValueError:
            print("❌ Please enter a valid number.")
            continue
        if 1 <= value <= 99:
            return value
        print("❌ Please enter a percentage between 1 and 99.")


def main(data_file="f cleaned.csv"):
    """Run the interactive analysis workflow, from data loading to HTML export.

    Steps: load the CSV dataset, train the AI model (a training failure is
    reported but not fatal), prompt for a query and a similarity threshold,
    search for similar sequences, build the tree structure, and write the
    interactive figure to ``phylogenetic_tree_interactive.html`` in the
    current working directory.

    Args:
        data_file: Path of the CSV dataset. Defaults to the file name this
            script previously hard-coded.

    Returns:
        None. The function returns early (after printing the reason) when
        the data cannot be loaded, the user cancels, no matches are found,
        or tree/figure construction fails.

    Raises:
        SystemExit: With status 1 when the run is interrupted with Ctrl-C
            or an unexpected exception occurs during the analysis.
    """
    print("\n" + "="*70)
    print("🧬 PHYLOGENETIC TREE ANALYZER - ADVANCED ML-BASED ANALYSIS")
    print("="*70)
    print("Version 2.0 | AI-Enhanced Similarity Matching")
    print("Interactive Visualization with Variable Line Lengths")
    print("="*70)

    # The data file is not asked for interactively, so retrying cannot
    # change the outcome: validate once and stop instead of looping forever.
    if not data_file:
        print("❌ Please provide a file path.")
        return
    if not Path(data_file).exists():
        print(f"❌ File not found: {data_file}")
        return

    # Initialize the analyzer
    analyzer = PhylogeneticTreeAnalyzer()

    try:
        # Step 1: Load data
        if not analyzer.load_data(data_file):
            print("❌ Failed to load data. Please check file format.")
            return

        # Step 2: Train AI model automatically
        print("\n⏳ Training AI model... This may take a few moments.", flush=True)
        start_time = time.time()
        if analyzer.train_ai_model():
            elapsed = time.time() - start_time
            print(f"✅ AI model training completed in {elapsed:.1f} seconds", flush=True)
        else:
            print("⚠️ AI model training failed, continuing with basic analysis", flush=True)

        # Step 3: Get query sequence
        if not _ask_for_query(analyzer):
            return

        # Step 4: Set similarity percentage
        target_percentage = _ask_for_similarity()
        analyzer.matching_percentage = target_percentage

        # Step 5: Find similar sequences
        print(f"\n⏳ Analyzing sequences for {target_percentage}% similarity...")
        start_time = time.time()
        matched_ids, actual_percentage = analyzer.find_similar_sequences(target_percentage)
        if not matched_ids:
            print(f"❌ No similar sequences found at {target_percentage}% similarity.")
            print("💡 Try lowering the similarity percentage (e.g., 70-80%) to find more distant matches.")
            return
        analyzer.matched_sequences = matched_ids
        analyzer.actual_percentage = actual_percentage
        elapsed = time.time() - start_time
        print(f"✅ Similarity analysis completed in {elapsed:.1f} seconds")

        # Step 6: Build tree structure
        print("\n⏳ Building phylogenetic tree structure...")
        start_time = time.time()
        if not analyzer.build_tree_structure_with_ml_safe(matched_ids):
            print("❌ Failed to build tree structure.")
            return
        elapsed = time.time() - start_time
        print(f"✅ Tree structure built in {elapsed:.1f} seconds")

        # Step 7: Create visualization and save HTML
        print("\n⏳ Creating interactive visualization...")
        start_time = time.time()
        fig = analyzer.create_interactive_tree(matched_ids, actual_percentage)
        if not fig:
            print("❌ Visualization creation failed.")
            return
        elapsed = time.time() - start_time
        print(f"✅ Visualization created in {elapsed:.1f} seconds")

        # Save the interactive HTML file
        html_filename = "phylogenetic_tree_interactive.html"
        fig.write_html(html_filename)
        print(f"📄 Interactive HTML saved: {html_filename}")
        print("\n🎉 Analysis completed successfully!")
        print(f" Query ID: {analyzer.query_id}")
        print(f" Query sequence length: {len(analyzer.query_sequence)} nucleotides")
        print(f" Similar sequences found: {len(matched_ids)}")
        print(f" Actual similarity percentage: {actual_percentage:.1f}%")
        print(f" HTML file generated: {html_filename}")
    except KeyboardInterrupt:
        print("\n\n⚠️ Analysis interrupted by user.")
        sys.exit(1)
    except Exception as e:
        # Top-level boundary of the interactive run: report and exit non-zero.
        print(f"\n❌ An error occurred during analysis: {e}")
        print("Please check your input data and try again.")
        sys.exit(1)
def _safe_filename_stem(query, max_length=100):
    """Reduce a query string to a short, filesystem-safe file name fragment."""
    # A query may be a raw nucleotide sequence or contain path separators,
    # so replace anything unusual and cap the length.
    cleaned = re.sub(r"[^A-Za-z0-9._-]+", "_", query)
    return cleaned[:max_length] or "query"


def command_line_interface():
    """Run the analyzer non-interactively from command-line arguments.

    Parses ``sys.argv``, loads the dataset, optionally trains the AI model,
    then processes either the single ``--query`` or every entry of
    ``--batch`` (``--batch`` takes precedence when both are given). For each
    query with matches, an interactive HTML tree is written to
    ``--output-dir``; with ``--save-json`` a JSON summary of the run is
    written next to it.

    Returns:
        None.

    Raises:
        SystemExit: Status 2 for argument errors reported by argparse
            (including neither ``--query`` nor ``--batch`` being given);
            status 1 when the similarity is outside 70-99, the data file is
            missing, the batch contains no query IDs, or the data fails to
            load.
    """
    parser = argparse.ArgumentParser(
        description="Advanced Phylogenetic Tree Analyzer with AI-enhanced similarity matching",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        # The exported source had these lines prefixed with "# ", an artifact
        # of the notebook export treating a leading "%" as an IPython magic.
        epilog=(
            "Examples:\n"
            "  %(prog)s -d data.csv -q MH087032 -s 95\n"
            "  %(prog)s -d data.csv -s 90 --no-ai --batch query1,query2,query3\n"
        ),
    )
    parser.add_argument('-d', '--data', required=True,
                        help='Path to CSV data file')
    parser.add_argument('-q', '--query',
                        help='Query sequence ID or nucleotide sequence '
                             '(ignored when --batch is given)')
    parser.add_argument('-s', '--similarity', type=float, default=95.0,
                        help='Target similarity percentage (70-99, default: 95)')
    parser.add_argument('--no-ai', action='store_true',
                        help='Skip AI model training')
    parser.add_argument('--batch',
                        help='Comma-separated list of query IDs for batch processing')
    parser.add_argument('--output-dir', default='.',
                        help='Output directory for results')
    parser.add_argument('--save-json', action='store_true',
                        help='Save detailed results to JSON')
    args = parser.parse_args()

    # --batch replaces --query, so only one of the two has to be present.
    if not args.query and not args.batch:
        parser.error("one of -q/--query or --batch is required")

    # Validate arguments
    if not (70 <= args.similarity <= 99):
        print("❌ Similarity percentage must be between 70 and 99.")
        sys.exit(1)
    if not Path(args.data).exists():
        print(f"❌ Data file not found: {args.data}")
        sys.exit(1)

    if args.batch:
        # Drop empty entries such as those produced by "a,,b" or a trailing comma.
        queries = [item.strip() for item in args.batch.split(',') if item.strip()]
    else:
        queries = [args.query.strip()]
    if not queries:
        print("❌ No query IDs found in --batch.")
        sys.exit(1)

    # Initialize analyzer
    analyzer = PhylogeneticTreeAnalyzer()

    # Load data
    if not analyzer.load_data(args.data):
        print("❌ Failed to load data.")
        sys.exit(1)

    # Train AI model (unless disabled)
    if not args.no_ai:
        print("\n⏳ Training AI model... This may take a few moments.", flush=True)
        start_time = time.time()
        if analyzer.train_ai_model():
            elapsed = time.time() - start_time
            print(f"✅ AI model training completed in {elapsed:.1f} seconds", flush=True)
        else:
            print("⚠️ AI model training failed, continuing with basic analysis", flush=True)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Process queries
    for query in queries:
        print(f"\n🔍 Processing: {query}")
        if not analyzer.find_query_sequence(query):
            print(f"❌ Query not found: {query}")
            continue

        matched_ids, actual_percentage = analyzer.find_similar_sequences(args.similarity)
        if not matched_ids:
            print(f"❌ No similar sequences found for {query}")
            continue

        # Record the run on the analyzer the same way the interactive
        # main() does before building the tree.
        analyzer.matching_percentage = args.similarity
        analyzer.matched_sequences = matched_ids
        analyzer.actual_percentage = actual_percentage

        analyzer.build_tree_structure_with_ml_safe(matched_ids)
        fig = analyzer.create_interactive_tree(matched_ids, actual_percentage)

        stem = _safe_filename_stem(query)
        html_path = None
        if fig:
            # Save the interactive HTML file
            html_path = output_dir / f"phylogenetic_tree_{stem}_interactive.html"
            fig.write_html(str(html_path))
            print(f"📄 Interactive HTML saved: {html_path}")

        if args.save_json:
            query_id = analyzer.query_id
            summary = {
                "query": query,
                "query_id": None if query_id is None else str(query_id),
                "target_similarity": args.similarity,
                "actual_similarity": float(actual_percentage),
                "matched_sequences": [str(item) for item in matched_ids],
                "html_file": None if html_path is None else str(html_path),
            }
            json_path = output_dir / f"phylogenetic_tree_{stem}_results.json"
            with open(json_path, "w", encoding="utf-8") as handle:
                json.dump(summary, handle, ensure_ascii=False, indent=2)
            print(f"📄 JSON results saved: {json_path}")

        print(f"✅ Analysis completed for {query}")
def _run_as_script():
    """Invoke main() and translate stray exceptions into exit codes.

    main() already exits with status 1 on Ctrl-C or on an error raised
    inside its own try block, so these handlers only cover exceptions
    raised outside of it.
    """
    try:
        main()
    except KeyboardInterrupt:
        print("\n\n👋 Goodbye!")
        sys.exit(0)
    except Exception as exc:
        print(f"\n❌ Unexpected error: {exc}")
        sys.exit(1)


if __name__ == "__main__":
    _run_as_script()
| #KR815908 |