re-type committed on
Commit
3668df7
·
verified ·
1 Parent(s): 2d487be

Upload ml_simplified_tree.py

Files changed (1)
  1. ml_simplified_tree.py +2027 -0
ml_simplified_tree.py ADDED
@@ -0,0 +1,2027 @@
+ # -*- coding: utf-8 -*-
+ """ML simplified tree.ipynb
+
+ Automatically generated by Colab.
+
+ Original file is located at
+     https://colab.research.google.com/drive/1LiDjip-h70ilIex9PedpWCZARWglija7
+ """
+
+ # Notebook shell magic is a syntax error in a plain .py file, so it is commented out.
+ # Install the dependencies from a shell instead: pip install biopython plotly
+ # !pip install biopython plotly
+
+ import pandas as pd
+ import numpy as np
+ import plotly.graph_objects as go
+ import plotly.offline as pyo
+ from plotly.subplots import make_subplots
+ from Bio import Phylo, SeqIO, AlignIO
+ from Bio.Phylo.TreeConstruction import DistanceCalculator, DistanceTreeConstructor
+ from Bio.Align import MultipleSeqAlignment
+ from Bio.Seq import Seq
+ from Bio.SeqRecord import SeqRecord
+ from sklearn.feature_extraction.text import CountVectorizer
+ from sklearn.metrics.pairwise import cosine_similarity
+ from sklearn.ensemble import RandomForestClassifier
+ from sklearn.model_selection import train_test_split
+ from sklearn.preprocessing import LabelEncoder
+ import warnings
+ import os
+ import sys
+ from typing import Dict, List, Tuple, Optional, Any
+ import json
+ import re
+ from scipy.optimize import minimize
+ from scipy.spatial.distance import pdist, squareform
+ from Bio.Phylo import BaseTree
+ import itertools
+ from collections import defaultdict, deque
+ import argparse
+ import time
+ from pathlib import Path
+
+ warnings.filterwarnings('ignore')
+
+ class PhylogeneticTreeAnalyzer:
+
+     def __init__(self):
+
+         self.data = None
+         self.query_sequence = None
+         self.query_id = None
+         self.matching_percentage = 95.0
+         self.actual_percentage = None
+         self.matched_sequences = []
+         self.tree_structure = {}
+         self.similarity_scores = {}
+         self.ai_model = None
+         self.label_encoder = LabelEncoder()
+         # ML-specific attributes
+         self.ml_tree = None
+         self.ml_alignment = None
+         self.ml_results = {}
+         self.horizontal_line_tracker = []  # Track horizontal lines with verticals
+         self.query_ml_group = None  # Track which ML group contains the query
+         self.base_horizontal_length = 1.2  # Base length for horizontal lines
+
+     def load_data(self, data_file: str):
+         """Load the sequence table from a CSV file; return True on success."""
+         try:
+             self.data = pd.read_csv(data_file)
+             # required_columns = ['Accession Number', 'ML', 'Genotype', 'Host',
+             #                     'Country', 'Isolate', 'Year', 'F-gene']
+
+             # Disabled check: a column is missing when it is required but absent from the data
+             # missing_columns = [col for col in required_columns if col not in self.data.columns]
+             # if missing_columns:
+             #     print(f"Error: Missing required columns: {missing_columns}")
+             #     return False
+
+             print(f"✓ Data loaded successfully: {len(self.data)} sequences")
+             print(f"✓ ML Groups found: {self.data['ML'].nunique()}")
+             print(f"✓ Genotypes found: {self.data['Genotype'].nunique()}")
+             return True
+
+         except Exception as e:
+             print(f"Error loading data: {e}")
+             return False
+
+
+     def calculate_f_gene_similarity(self, seq1: str, seq2: str) -> float:
+
+         try:
+             # Handle empty or None sequences
+             if not seq1 or not seq2:
+                 return 0.0
+
+             # Convert to uppercase and remove non-nucleotide characters
+             seq1 = re.sub(r'[^ATGC]', '', str(seq1).upper())
+             seq2 = re.sub(r'[^ATGC]', '', str(seq2).upper())
+
+             if len(seq1) == 0 or len(seq2) == 0:
+                 return 0.0
+
+             # Use k-mer analysis for similarity calculation
+             k = 5  # 5-mer analysis
+             kmers1 = set([seq1[i:i+k] for i in range(len(seq1)-k+1) if len(seq1[i:i+k]) == k])
+             kmers2 = set([seq2[i:i+k] for i in range(len(seq2)-k+1) if len(seq2[i:i+k]) == k])
+
+             if len(kmers1) == 0 and len(kmers2) == 0:
+                 return 100.0
+             elif len(kmers1) == 0 or len(kmers2) == 0:
+                 return 0.0
+
+             # Calculate Jaccard similarity
+             intersection = len(kmers1.intersection(kmers2))
+             union = len(kmers1.union(kmers2))
+             similarity = (intersection / union) * 100 if union > 0 else 0.0
+
+             return round(similarity, 2)
+
+         except Exception as e:
+             print(f"Error calculating similarity: {e}")
+             return 0.0
+
+     def train_ai_model(self):
+         """Train a random forest that predicts the ML group from F-gene k-mer counts."""
+         try:
+
+             # Skip training if insufficient data
+             if len(self.data) < 10:  # Require minimum 10 samples
+                 print("⚠️ Insufficient data to train AI model (min 10 samples required)", flush=True)
+                 return False
+
+             print("🤖 Training AI model for sequence analysis...", flush=True)
+
+             # Prepare features from F-gene sequences
+             f_gene_sequences = self.data['F-gene'].fillna('').astype(str)
+
+             # Feature columns: the first 50 3-mers and the first 50 4-mers in
+             # itertools.product order (a fixed subset, not the most frequent k-mers)
+             all_3mers = [''.join(p) for p in itertools.product('ATGC', repeat=3)]
+             all_4mers = [''.join(p) for p in itertools.product('ATGC', repeat=4)]
+
+             # Create k-mer count features (3-mers and 4-mers)
+             features = []
+             for seq in f_gene_sequences:
+                 seq_clean = re.sub(r'[^ATGC]', '', seq.upper())
+                 if len(seq_clean) < 3:
+                     features.append([0] * 100)  # All-zero vector for short sequences
+                     continue
+
+                 feature_vector = []
+                 # 3-mers
+                 kmers_3 = [seq_clean[i:i+3] for i in range(len(seq_clean)-2)]
+                 kmer_counts_3 = {kmer: kmers_3.count(kmer) for kmer in set(kmers_3)}
+
+                 # 4-mers
+                 kmers_4 = [seq_clean[i:i+4] for i in range(len(seq_clean)-3)]
+                 kmer_counts_4 = {kmer: kmers_4.count(kmer) for kmer in set(kmers_4)}
+
+                 # Build the 100-element feature vector (50 3-mer counts + 50 4-mer counts)
+                 feature_vector.extend([kmer_counts_3.get(kmer, 0) for kmer in all_3mers[:50]])
+                 feature_vector.extend([kmer_counts_4.get(kmer, 0) for kmer in all_4mers[:50]])
+
+                 features.append(feature_vector)
+
+             # Prepare target labels (ML groups)
+             targets = self.label_encoder.fit_transform(self.data['ML'].fillna('Unknown'))
+
+             # Skip if only 1 class
+             if len(np.unique(targets)) < 2:
+                 print("⚠️ Need at least 2 distinct classes for training", flush=True)
+                 return False
+
+             # Train Random Forest model
+             X = np.array(features)
+             y = targets
+
+             X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
+
+             self.ai_model = RandomForestClassifier(n_estimators=100, random_state=42)
+             self.ai_model.fit(X_train, y_train)
+
+             # Calculate accuracy on the held-out 20% split
+             accuracy = self.ai_model.score(X_test, y_test)
+             print(f"✓ AI model trained successfully with accuracy: {accuracy:.2%}", flush=True)
+
+             return True
+
+         except Exception as e:
+             print(f"🚨 CRITICAL training error: {e}", flush=True)
+             import traceback
+             traceback.print_exc()
+             return False
+
+     def find_query_sequence(self, query_input: str) -> bool:
+         """
+         Modified to accept any sequence input, not just those existing in the dataset.
+         """
+         try:
+             # Check if input is an accession number from the dataset
+             if query_input in self.data['Accession Number'].values:
+                 self.query_id = query_input
+                 query_row = self.data[self.data['Accession Number'] == query_input].iloc[0]
+                 self.query_sequence = query_row['F-gene']
+                 print(f"✓ Query sequence found by ID: {query_input}")
+                 return True
+
+             # Check if input is a nucleotide sequence
+             query_clean = re.sub(r'[^ATGC]', '', str(query_input).upper())
+
+             # Accept any sequence with reasonable length (even short ones for testing)
+             if len(query_clean) >= 10:  # Minimum sequence length (reduced from 50)
+                 # For sequences not in dataset, create a unique identifier
+                 if query_input not in self.data['Accession Number'].values:
+                     # Generate a query ID for novel sequences.
+                     # Note: built-in hash() of a str is salted per process unless
+                     # PYTHONHASHSEED is set, so this ID is not stable across runs.
+                     self.query_id = f"QUERY_{hash(query_clean) % 100000:05d}"
+                     self.query_sequence = query_clean
+                     print(f"✓ Novel query sequence accepted with ID: {self.query_id}")
+                     print(f"  Sequence length: {len(query_clean)} nucleotides")
+                     return True
+                 else:
+                     # If somehow it matches an accession but wasn't caught above
+                     self.query_id = query_input
+                     self.query_sequence = query_clean
+                     print(f"✓ Query sequence processed: {query_input}")
+                     return True
+
+             # If sequence is too short or invalid
+             if len(query_clean) < 10:
+                 print(f"✗ Query sequence too short. Minimum length: 10 nucleotides (provided: {len(query_clean)})")
+             else:
+                 print(f"✗ Invalid sequence format. Please provide nucleotides (A, T, G, C) or valid accession number")
+
+             return False
+
+         except Exception as e:
+             print(f"Error processing query sequence: {e}")
+             return False
+
+     def find_similar_sequences(self, target_percentage: float) -> Tuple[List[str], float]:
+         """
+         Modified to work with any query sequence, including novel ones not in the dataset.
+         """
+         try:
+             print(f"🔍 Finding sequences with {target_percentage}% similarity to query...")
+             similarities = []
+
+             # Calculate similarity between query and all sequences in dataset
+             for idx, row in self.data.iterrows():
+                 # Skip if this is the same sequence (only relevant for existing accession numbers)
+                 if hasattr(self, 'query_id') and row['Accession Number'] == self.query_id:
+                     continue
+
+                 try:
+                     similarity = self.calculate_f_gene_similarity(self.query_sequence, row['F-gene'])
+                     similarities.append({
+                         'id': row['Accession Number'],
+                         'similarity': similarity,
+                         'ml': row['ML'] if 'ML' in row else 'Unknown',
+                         'genotype': row['Genotype'] if 'Genotype' in row else 'Unknown'
+                     })
+                 except Exception as seq_error:
+                     print(f"⚠ Skipping sequence {row['Accession Number']}: {seq_error}")
+                     continue
+
+             if not similarities:
+                 print("❌ No valid sequences found for comparison")
+                 return [], target_percentage
+
+             # Sort by similarity (highest first)
+             similarities.sort(key=lambda x: x['similarity'], reverse=True)
+
+             # Find sequences within target percentage range (±2%)
+             target_range = 2.0
+             candidates = [s for s in similarities
+                           if abs(s['similarity'] - target_percentage) <= target_range]
+
+             if not candidates:
+                 # If no exact matches, find sequences with closest similarity
+                 closest_sim = min(similarities, key=lambda x: abs(x['similarity'] - target_percentage))
+                 actual_percentage = closest_sim['similarity']
+
+                 # Get sequences within ±1% of the closest similarity
+                 candidates = [s for s in similarities
+                               if abs(s['similarity'] - actual_percentage) <= 1.0]
+
+                 print(f"⚠ No sequences found at exactly {target_percentage}%. Using closest: {actual_percentage:.1f}%")
+             else:
+                 actual_percentage = target_percentage
+
+             # Limit results to prevent overwhelming visualization (optional)
+             max_results = 50  # Adjust as needed
+             if len(candidates) > max_results:
+                 candidates = candidates[:max_results]
+                 print(f"⚠ Limited results to top {max_results} matches for better visualization")
+
+             # Store similarity scores for later use
+             self.similarity_scores = {}  # Reset similarity scores
+             for candidate in candidates:
+                 self.similarity_scores[candidate['id']] = candidate['similarity']
+
+             matched_ids = [c['id'] for c in candidates]
+
+             # Show some statistics
+             if similarities:
+                 max_sim = max(similarities, key=lambda x: x['similarity'])['similarity']
+                 min_sim = min(similarities, key=lambda x: x['similarity'])['similarity']
+                 avg_sim = sum(s['similarity'] for s in similarities) / len(similarities)
+
+                 print(f"✓ Found {len(matched_ids)} sequences at ~{actual_percentage:.1f}% similarity")
+                 print(f"  Similarity range in dataset: {min_sim:.1f}% - {max_sim:.1f}% (avg: {avg_sim:.1f}%)")
+
+             return matched_ids, actual_percentage
+
+         except Exception as e:
+             print(f"Error finding similar sequences: {e}")
+             return [], target_percentage
+
+
+     def build_tree_structure(self, matched_ids: List[str]) -> Dict:
+         try:
+             print("🌳 Building normalized horizontal tree structure...")
+
+             # Initialize tree structure
+             tree_structure = {
+                 'root': {
+                     'name': 'Root',
+                     'type': 'root',
+                     'children': {},
+                     'x': 0,
+                     'y': 0,
+                     'has_vertical_attachment': False,
+                     'extension_level': 0
+                 }
+             }
+
+             # Group sequences by ML and Genotype
+             ml_groups = {}
+             for idx, row in self.data.iterrows():
+                 ml_group = row['ML']
+                 genotype = row['Genotype']
+                 seq_id = row['Accession Number']
+
+                 if ml_group not in ml_groups:
+                     ml_groups[ml_group] = {}
+
+                 if genotype not in ml_groups[ml_group]:
+                     ml_groups[ml_group][genotype] = []
+
+                 ml_groups[ml_group][genotype].append({
+                     'id': seq_id,
+                     'data': row.to_dict(),
+                     'is_query': seq_id == self.query_id,
+                     'is_matched': seq_id in matched_ids,
+                     'similarity': self.similarity_scores.get(seq_id, 0.0)
+                 })
+
+             # Normalize ML group names and organize
+             normalized_ml_groups = self._normalize_ml_groups(ml_groups)
+
+             # Build normalized ML level - horizontal layout with progressive extensions
+             self._build_normalized_ml_nodes(tree_structure, normalized_ml_groups, matched_ids)
+
+             self.tree_structure = tree_structure
+             print(f"✓ Normalized horizontal tree structure built")
+             return tree_structure
+
+         except Exception as e:
+             print(f"Error building tree structure: {e}")
+             return {}
+
+     def _normalize_ml_groups(self, ml_groups: Dict) -> Dict:
+         """Normalize ML group names and organize hierarchically"""
+         try:
+             normalized_groups = {}
+
+             for ml_name, genotypes in ml_groups.items():
+                 # Extract base ML name
+                 if ml_name.startswith('UNCL'):
+                     base_ml = 'UNCL'
+                 elif '.' in ml_name and any(char.isdigit() for char in ml_name):
+                     # For names like XII.1.2, XII.1, etc., extract the base (XII)
+                     base_ml = ml_name.split('.')[0]
+                 else:
+                     base_ml = ml_name
+
+                 # Initialize normalized group structure
+                 if base_ml not in normalized_groups:
+                     normalized_groups[base_ml] = {
+                         'full_ml_groups': {},
+                         'representative_sequences': [],
+                         'has_special_sequences': False
+                     }
+
+                 # Check if this ML group has query or matched sequences
+                 has_special = any(
+                     any(seq['is_query'] or seq['is_matched'] for seq in sequences)
+                     for sequences in genotypes.values()
+                 )
+
+                 if has_special:
+                     normalized_groups[base_ml]['has_special_sequences'] = True
+                     normalized_groups[base_ml]['full_ml_groups'][ml_name] = genotypes
+                 else:
+                     # Add as representative (limit to 2 representatives)
+                     if len(normalized_groups[base_ml]['representative_sequences']) < 2:
+                         # Get 1-2 representative sequences from this ML group
+                         for genotype, sequences in list(genotypes.items())[:2]:
+                             if len(normalized_groups[base_ml]['representative_sequences']) < 2:
+                                 normalized_groups[base_ml]['representative_sequences'].extend(sequences[:1])
+
+             return normalized_groups
+
+         except Exception as e:
+             print(f"Error normalizing ML groups: {e}")
+             return {}
+
+     def _build_normalized_ml_nodes(self, tree_structure: Dict, normalized_ml_groups: Dict, matched_ids: List[str]):
+         """Build normalized ML nodes with equal spacing and progressive horizontal extensions"""
+         try:
+             # Reset horizontal line tracker
+             self.horizontal_line_tracker = []
+
+             # Identify which ML group contains the query
+             self._identify_query_ml_group(normalized_ml_groups)
+
+             # Calculate equal spacing for all ML groups
+             ml_positions = self._calculate_dynamic_ml_positions(normalized_ml_groups)
+
+             # Mark root as having vertical attachment if it has multiple children
+             root_has_vertical = len(normalized_ml_groups) > 1
+             tree_structure['root']['has_vertical_attachment'] = root_has_vertical
+
+             for ml_idx, (base_ml, ml_data) in enumerate(normalized_ml_groups.items()):
+                 y_pos = ml_positions[ml_idx]
+
+                 # Determine if this ML node will have vertical attachments
+                 has_vertical = ml_data['has_special_sequences'] and len(ml_data['full_ml_groups']) > 1
+
+                 # Check if this ML group contains the query
+                 contains_query = (base_ml == self.query_ml_group)
+
+                 # Calculate horizontal line length based on connections and query presence
+                 horizontal_length = self._determine_horizontal_line_length(
+                     'normalized_ml_group', has_vertical, contains_query
+                 )
+
+                 x_pos = horizontal_length
+
+                 # Create normalized ML node
+                 tree_structure['root']['children'][base_ml] = {
+                     'name': base_ml,
+                     'type': 'normalized_ml_group',
+                     'children': {},
+                     'x': x_pos,
+                     'y': y_pos,
+                     'has_special_sequences': ml_data['has_special_sequences'],
+                     'has_vertical_attachment': has_vertical,
+                     'horizontal_line_length': horizontal_length,
+                     'contains_query': contains_query
+                 }
+
+                 if ml_data['has_special_sequences']:
+                     # Build full ML nodes for groups with special sequences
+                     self._build_full_ml_nodes(
+                         tree_structure['root']['children'][base_ml],
+                         ml_data['full_ml_groups'],
+                         y_pos,
+                         matched_ids,
+                         x_pos
+                     )
+                 else:
+                     # Add representative sequences directly
+                     self._add_representative_sequences(
+                         tree_structure['root']['children'][base_ml],
+                         ml_data['representative_sequences'],
+                         y_pos,
+                         x_pos
+                     )
+
+         except Exception as e:
+             print(f"Error building normalized ML nodes: {e}")
+
+     def _calculate_dynamic_ml_positions(self, normalized_ml_groups: Dict) -> List[float]:
+         """Calculate equal Y positions for all ML groups regardless of content"""
+         try:
+             ml_count = len(normalized_ml_groups)
+             if ml_count == 0:
+                 return []
+
+             if ml_count == 1:
+                 return [0.0]
+
+             # Equal spacing between all ML nodes
+             total_spacing = (ml_count - 1) * 2.0  # 2.0 units between each ML node
+             start_y = -total_spacing / 2
+
+             positions = []
+             for i in range(ml_count):
+                 positions.append(start_y + i * 2.0)
+
+             return positions
+
+         except Exception as e:
+             print(f"Error calculating dynamic positions: {e}")
+             return list(range(len(normalized_ml_groups)))
+
+     def _build_full_ml_nodes(self, normalized_ml_node: Dict, full_ml_groups: Dict, base_y: float, matched_ids: List[str], parent_x: float):
+         """Build full ML nodes with genotypes for groups containing special sequences"""
+         try:
+             # Calculate equal positions for full ML groups
+             full_ml_positions = self._calculate_full_ml_positions(full_ml_groups, base_y)
+
+             for ml_idx, (full_ml_name, genotypes) in enumerate(full_ml_groups.items()):
+                 y_pos = full_ml_positions[ml_idx]
+
+                 # Determine if this full ML node will have vertical attachments
+                 special_genotypes_count = sum(1 for genotype, sequences in genotypes.items()
+                                               if any(seq['is_query'] or seq['is_matched'] for seq in sequences))
+                 has_vertical = special_genotypes_count > 1
+
+                 # Check if this full ML group contains the query
+                 contains_query = any(
+                     any(seq['is_query'] for seq in sequences)
+                     for sequences in genotypes.values()
+                 )
+
+                 # Calculate horizontal line length
+                 horizontal_length = self._determine_horizontal_line_length(
+                     'full_ml_group', has_vertical, contains_query
+                 )
+
+                 x_pos = parent_x + horizontal_length
+
+                 # Create full ML node
+                 normalized_ml_node['children'][full_ml_name] = {
+                     'name': full_ml_name,
+                     'type': 'full_ml_group',
+                     'children': {},
+                     'x': x_pos,
+                     'y': y_pos,
+                     'sequences_count': sum(len(seqs) for seqs in genotypes.values()),
+                     'has_vertical_attachment': has_vertical,
+                     'horizontal_line_length': horizontal_length,
+                     'contains_query': contains_query
+                 }
+
+                 # Build genotype nodes
+                 self._build_genotype_nodes(
+                     normalized_ml_node['children'][full_ml_name],
+                     genotypes,
+                     y_pos,
+                     matched_ids,
+                     x_pos
+                 )
+
+         except Exception as e:
+             print(f"Error building full ML nodes: {e}")
+
+     def _calculate_full_ml_positions(self, full_ml_groups: Dict, base_y: float) -> List[float]:
+         """Calculate equal positions for full ML groups"""
+         try:
+             ml_count = len(full_ml_groups)
+             if ml_count <= 1:
+                 return [base_y]
+
+             # Equal spacing for full ML groups
+             spacing = 1.5  # Fixed spacing between full ML groups
+             start_y = base_y - (spacing * (ml_count - 1)) / 2
+
+             positions = []
+             for i in range(ml_count):
+                 positions.append(start_y + i * spacing)
+
+             return positions
+
+         except Exception as e:
+             print(f"Error calculating full ML positions: {e}")
+             return [base_y] * len(full_ml_groups)
+
+     def _build_genotype_nodes(self, full_ml_node: Dict, genotypes: Dict, base_y: float, matched_ids: List[str], parent_x: float):
+         """Build genotype nodes with sequences - horizontal line length based on sequence count"""
+         try:
+             # Filter genotypes with special sequences
+             special_genotypes = []
+             for genotype, sequences in genotypes.items():
+                 if any(seq['is_query'] or seq['is_matched'] for seq in sequences):
+                     special_genotypes.append((genotype, sequences))
+
+             if not special_genotypes:
+                 return
+
+             # Calculate equal genotype positions (vertical positioning remains equal)
+             genotype_positions = self._calculate_genotype_positions(special_genotypes, base_y)
+
+             # Calculate sequence counts for each genotype to determine horizontal line lengths
+             genotype_sequence_counts = []
+             for genotype, sequences in special_genotypes:
+                 special_sequences = [seq for seq in sequences if seq['is_query'] or seq['is_matched']]
+                 genotype_sequence_counts.append((genotype, sequences, len(special_sequences)))
+
+             for gt_idx, (genotype, sequences, sequence_count) in enumerate(genotype_sequence_counts):
+                 y_pos = genotype_positions[gt_idx]
+
+                 # Determine if this genotype will have vertical attachments
+                 special_sequences = [seq for seq in sequences if seq['is_query'] or seq['is_matched']]
+                 has_vertical = len(special_sequences) > 1
+
+                 # Check if this genotype contains the query
+                 contains_query = any(seq['is_query'] for seq in sequences)
+
+                 # Calculate horizontal line length based on sequence count
+                 horizontal_length = self._determine_genotype_horizontal_line_length(
+                     sequence_count, has_vertical, contains_query
+                 )
+
+                 x_pos = parent_x + horizontal_length
+
+                 # Create genotype node
+                 full_ml_node['children'][genotype] = {
+                     'name': genotype,
+                     'type': 'genotype',
+                     'children': {},
+                     'x': x_pos,
+                     'y': y_pos,
+                     'sequences': sequences,
+                     'has_vertical_attachment': has_vertical,
+                     'horizontal_line_length': horizontal_length,
+                     'contains_query': contains_query,
+                     'sequence_count': sequence_count  # Store for reference
+                 }
+
+                 # Add sequences horizontally
+                 self._add_sequences_horizontal(
+                     full_ml_node['children'][genotype],
+                     sequences,
+                     y_pos,
+                     x_pos
+                 )
+
+         except Exception as e:
+             print(f"Error building genotype nodes: {e}")
+
+     def _determine_genotype_horizontal_line_length(self, sequence_count: int, has_vertical: bool, contains_query: bool = False) -> float:
+         """Determine horizontal line length for genotype nodes based on sequence count"""
+         try:
+             base_length = self.base_horizontal_length
+
+             # Special case: Genotype containing query sequence gets additional length
+             query_bonus = 0.5 if contains_query else 0.0
+
+             # Calculate length based on sequence count
+             # More sequences = longer horizontal line
+             if sequence_count <= 1:
+                 # Single sequence
+                 length_multiplier = 1.0
+             elif sequence_count <= 3:
+                 # 2-3 sequences
+                 length_multiplier = 1.6
+             elif sequence_count <= 5:
+                 # 4-5 sequences
+                 length_multiplier = 2.3
+             elif sequence_count <= 8:
+                 # 6-8 sequences
+                 length_multiplier = 6.0
+             else:
+                 # More than 8 sequences
+                 length_multiplier = 6.0
+
+             # Calculate final length
+             calculated_length = base_length * length_multiplier + query_bonus
+
+             return calculated_length
+
+         except Exception as e:
+             print(f"Error determining genotype horizontal line length: {e}")
+             return self.base_horizontal_length
+
+     def _calculate_genotype_positions(self, special_genotypes: List, base_y: float) -> List[float]:
+         """Calculate equal positions for genotypes"""
+         try:
+             genotype_count = len(special_genotypes)
+             if genotype_count <= 1:
+                 return [base_y]
+
+             # Equal spacing for genotypes
+             spacing = 1.0  # Fixed spacing between genotypes
+             start_y = base_y - (spacing * (genotype_count - 1)) / 2
+
+             positions = []
+             for i in range(genotype_count):
+                 positions.append(start_y + i * spacing)
+
+             return positions
+
+         except Exception as e:
+             print(f"Error calculating genotype positions: {e}")
+             return [base_y] * len(special_genotypes)
+
+     def _add_representative_sequences(self, normalized_ml_node: Dict, representative_sequences: List[Dict], base_y: float, parent_x: float):
+         """Add representative sequences directly to normalized ML node"""
+         try:
+             if not representative_sequences:
+                 return
+
+             # Calculate horizontal line length for representative sequences
+             # Representative sequences get a standard length (not similarity-based since they're not matched)
+             has_vertical = len(representative_sequences) > 1
+             horizontal_length = self._determine_horizontal_line_length('representative', has_vertical)
+             x_pos = parent_x + horizontal_length
+
+             if len(representative_sequences) == 1:
+                 seq = representative_sequences[0]
+                 normalized_ml_node['children'][f"{seq['id']}_rep"] = {
+                     'name': f"{seq['id']} (Rep)",
+                     'type': 'representative_sequence',
+                     'data': seq,
+                     'x': x_pos,
+                     'y': base_y,
+                     'has_vertical_attachment': False,
+                     'horizontal_line_length': horizontal_length
+                 }
+             else:
+                 # Equal spacing for multiple representative sequences
+                 positions = self._calculate_sequence_positions(representative_sequences, base_y)
+
+                 for idx, seq in enumerate(representative_sequences):
+                     normalized_ml_node['children'][f"{seq['id']}_rep"] = {
+                         'name': f"{seq['id']} (Rep)",
+                         'type': 'representative_sequence',
+                         'data': seq,
+                         'x': x_pos,
+                         'y': positions[idx],
+                         'has_vertical_attachment': False,
+                         'horizontal_line_length': horizontal_length
+                     }
+
+         except Exception as e:
+             print(f"Error adding representative sequences: {e}")
+
739
+ def _add_sequences_horizontal(self, genotype_node: Dict, sequences: List[Dict], base_y: float, parent_x: float):
740
+ """Add sequences horizontally with similarity-based line lengths"""
741
+ try:
742
+ # Define the query line length as the reference (100%)
743
+ query_line_length = 3.0 # Base length for query sequence (100%)
744
+
745
+ # Separate query and matched sequences
746
+ query_sequences = [seq for seq in sequences if seq['is_query']]
747
+ matched_sequences = [seq for seq in sequences if seq['is_matched'] and not seq['is_query']]
748
+
749
+ all_special_sequences = query_sequences + matched_sequences
750
+
751
+            # _calculate_sequence_positions returns [base_y] for a single sequence,
+            # so one loop handles both cases: equal vertical distribution with
+            # similarity-based horizontal lengths
+            sequence_positions = self._calculate_sequence_positions(all_special_sequences, base_y)
+
+            for seq_idx, sequence in enumerate(all_special_sequences):
+                line_length = self._calculate_similarity_based_line_length(sequence, query_line_length)
+                x_pos = parent_x + line_length
+
+                genotype_node['children'][sequence['id']] = {
+                    'name': f"{sequence['id']}{' (' + str(sequence['similarity']) + '%)' if sequence['is_matched'] else ''}",
+                    'type': 'sequence',
+                    'data': sequence,
+                    'x': x_pos,
+                    'y': sequence_positions[seq_idx],
+                    'has_vertical_attachment': False,
+                    'similarity_line_length': line_length
+                }
783
+ except Exception as e:
784
+ print(f"Error adding sequences horizontally: {e}")
785
+
786
+ def _calculate_similarity_based_line_length(self, sequence: Dict, query_line_length: float) -> float:
787
+ """Calculate line length based on similarity percentage relative to query"""
788
+ try:
789
+ if sequence['is_query']:
790
+ # Query sequence gets 100% length
791
+ return query_line_length
792
+ elif sequence['is_matched']:
793
+ # Matched sequences get length proportional to their similarity
794
+ similarity = sequence['similarity']
795
+ # Convert similarity percentage to proportional length
796
+ proportional_length = (similarity / 100.0) * query_line_length
797
+ # Ensure minimum length for visibility
798
+ min_length = query_line_length * 0.2 # Minimum 20% of query length
799
+ return max(proportional_length, min_length)
800
+ else:
801
+ # Other sequences get a standard length (50% of query)
802
+ return query_line_length * 0.5
803
+ except Exception as e:
804
+ print(f"Error calculating similarity-based line length: {e}")
805
+ return query_line_length * 0.5
806
+
807
+
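# Illustrative sketch, not part of the committed file: the mapping used by
# _calculate_similarity_based_line_length for matched sequences, written as a
# standalone function. The name similarity_line_length and its keyword defaults
# are hypothetical; the 3.0 reference length and the 20% floor come from the
# method above.

```python
def similarity_line_length(similarity: float, query_len: float = 3.0,
                           min_fraction: float = 0.2) -> float:
    """Scale a line to similarity% of the query length, with a visibility floor."""
    proportional = (similarity / 100.0) * query_len
    return max(proportional, query_len * min_fraction)


for sim in (100.0, 85.0, 10.0):
    # 10% falls below the floor, so it is drawn at 20% of the query length
    print(f"{sim:5.1f}% -> {similarity_line_length(sim):.2f}")
```

# Because of the floor, every match below 20% similarity is drawn at the same
# length, so very weak matches cannot be told apart by line length alone.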
808
+ def _calculate_sequence_positions(self, sequences: List[Dict], base_y: float) -> List[float]:
809
+ """Calculate equal positions for sequences"""
810
+ try:
811
+ seq_count = len(sequences)
812
+ if seq_count <= 1:
813
+ return [base_y]
814
+
815
+ # Equal spacing for sequences
816
+ spacing = 0.8 # Fixed spacing between sequences
817
+ start_y = base_y - (spacing * (seq_count - 1)) / 2
818
+
819
+ positions = []
820
+ for i in range(seq_count):
821
+ positions.append(start_y + i * spacing)
822
+
823
+ return positions
824
+
825
+ except Exception as e:
826
+ print(f"Error calculating sequence positions: {e}")
827
+ return [base_y] * len(sequences)
828
+
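# Illustrative sketch, not part of the committed file: the centered, equally
# spaced layout computed by _calculate_sequence_positions. The function name
# centered_positions is hypothetical; the 0.8 spacing is the value used above.

```python
from typing import List


def centered_positions(count: int, base_y: float, spacing: float = 0.8) -> List[float]:
    """Return `count` y-positions, `spacing` apart, centered on base_y."""
    if count <= 1:
        return [base_y]
    # Shift the first position up by half of the total span
    start_y = base_y - spacing * (count - 1) / 2
    return [start_y + i * spacing for i in range(count)]


positions = centered_positions(3, base_y=5.0)
print(positions)
```

# The mean of the returned positions equals base_y, which keeps the children
# vertically balanced around the parent's horizontal line.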
829
+ def _determine_horizontal_line_length(self, node_type: str, has_vertical: bool, contains_query: bool = False) -> float:
830
+ """Determine horizontal line length based on node type and connections"""
831
+ try:
832
+ base_length = self.base_horizontal_length
833
+
834
+ # Special case: ML group containing query sequence gets much longer line
835
+ if contains_query and node_type == 'normalized_ml_group':
836
+ return base_length * 2.5 # Much longer for query ML group
837
+
838
+ # If this node has a vertical line attachment (connects to multiple children)
839
+ if has_vertical:
840
+                # Find the current longest horizontal line with vertical
+                current_max = max([base_length, *self.horizontal_line_tracker])
845
+
846
+ # Make this line incrementally longer
847
+ new_length = current_max + 0.3
848
+ self.horizontal_line_tracker.append(new_length)
849
+ return new_length
850
+ else:
851
+ # Direct connection (no vertical), use base length
852
+ return base_length
853
+
854
+ except Exception as e:
855
+ print(f"Error determining horizontal line length: {e}")
856
+ return self.base_horizontal_length
857
+
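# Illustrative sketch, not part of the committed file: how the tracker in
# _determine_horizontal_line_length makes each new multi-child line 0.3 longer
# than the longest one so far. The name next_branch_length and the base length
# of 1.0 are assumptions; the value of self.base_horizontal_length is not
# shown in this part of the file.

```python
from typing import List


def next_branch_length(tracker: List[float], base_length: float, step: float = 0.3) -> float:
    """Return a length `step` longer than the longest tracked one and record it."""
    new_length = max([base_length, *tracker]) + step
    tracker.append(new_length)
    return new_length


tracker: List[float] = []
lengths = [next_branch_length(tracker, base_length=1.0) for _ in range(3)]
print([round(x, 2) for x in lengths])
```

# Note that the length grows with every multi-child node in the tree, so the
# order in which nodes are processed determines how long each line is.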
858
+ def _identify_query_ml_group(self, normalized_ml_groups: Dict):
859
+ """Identify which ML group contains the query sequence"""
860
+ try:
861
+ for base_ml, ml_data in normalized_ml_groups.items():
862
+ if ml_data['has_special_sequences']:
863
+ for full_ml_name, genotypes in ml_data['full_ml_groups'].items():
864
+ for genotype, sequences in genotypes.items():
865
+ if any(seq['is_query'] for seq in sequences):
866
+ self.query_ml_group = base_ml
867
+ return
868
+ except Exception as e:
869
+ print(f"Error identifying query ML group: {e}")
870
+
884
+ def _calculate_sequence_x_position_horizontal(self, sequence: Dict, max_similarity: float) -> float:
885
+ """Calculate X position based on similarity percentage for horizontal layout"""
886
+ # This function is now replaced by _calculate_similarity_based_line_length
887
+ # Keeping for backward compatibility, but the new approach is used in _add_sequences_horizontal
888
+
889
+ base_x = 0 # Relative to parent genotype node
890
+ query_line_length = 3.0 # Reference length for query (100%)
891
+
892
+ if sequence['is_query']:
893
+ return base_x + query_line_length # 100% length for query
894
+ elif sequence['is_matched']:
895
+ # Line length varies based on similarity percentage
896
+ similarity = sequence['similarity']
897
+ proportional_length = (similarity / 100.0) * query_line_length
898
+ min_length = query_line_length * 0.2 # Minimum 20% of query length
899
+ return base_x + max(proportional_length, min_length)
900
+ else:
901
+ return base_x + (query_line_length * 0.5) # 50% length for other sequences
902
+
903
+
904
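+    def create_interactive_tree(self, matched_ids: List[str], actual_percentage: float):
+        """Render self.tree_structure as a horizontal Plotly tree, save it as HTML, and return the figure (None on error)"""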
905
+ try:
906
+ print("🎨 Creating horizontal interactive tree visualization...")
907
+
908
+ # Prepare data for plotting
909
+ edge_x = []
910
+ edge_y = []
911
+ node_x = []
912
+ node_y = []
913
+ node_colors = []
914
+ node_text = []
915
+ node_hover = []
916
+ node_sizes = []
917
+
918
+ # Updated color scheme for new node types
919
+ colors = {
920
+ 'root': '#FF0000', # Red for root
921
+ 'normalized_ml_group': '#FFB6C1', # Light pink for normalized ML groups
922
+ 'full_ml_group': '#FF69B4', # Hot pink for full ML groups
923
+ 'genotype': '#FFD700', # Gold for genotypes
924
+ 'representative_sequence': '#FFA500', # Orange for representative sequences
925
+ 'query_sequence': '#4B0082', # Dark purple for query
926
+ 'matched_sequence': '#6A5ACD', # Slate blue for matched
927
+ 'other_sequence': '#87CEEB' # Sky blue for others
928
+ }
929
+
930
+ def add_horizontal_edges(parent_x, parent_y, children_dict):
931
+ """Add horizontal connecting lines with proper vertical line sizing"""
932
+ if not children_dict:
933
+ return
934
+
935
+ children_list = list(children_dict.values())
936
+
937
+ if len(children_list) == 1:
938
+ # Single child - direct horizontal line
939
+ child = children_list[0]
940
+ edge_x.extend([parent_x, child['x'], None])
941
+ edge_y.extend([parent_y, child['y'], None])
942
+ else:
943
+ # Multiple children - horizontal line with vertical distribution
944
+ # Calculate the intermediate x position (where vertical line will be)
945
+ child_x_positions = [child['x'] for child in children_list]
946
+ min_child_x = min(child_x_positions)
947
+ intermediate_x = parent_x + (min_child_x - parent_x) * 0.8 # 80% of the way to nearest child
948
+
949
+ # Horizontal line to intermediate point
950
+ edge_x.extend([parent_x, intermediate_x, None])
951
+ edge_y.extend([parent_y, parent_y, None])
952
+
953
+ # Calculate vertical line range to fit exactly all children
954
+ child_y_positions = [child['y'] for child in children_list]
955
+ min_y, max_y = min(child_y_positions), max(child_y_positions)
956
+
957
+ # Vertical line sized exactly to fit all children
958
+ edge_x.extend([intermediate_x, intermediate_x, None])
959
+ edge_y.extend([min_y, max_y, None])
960
+
961
+ # Horizontal lines from vertical line to each child
962
+ for child in children_list:
963
+ edge_x.extend([intermediate_x, child['x'], None])
964
+ edge_y.extend([child['y'], child['y'], None])
965
+
966
+ def get_node_color_and_size(node):
967
+ """Determine node color and size based on type and content"""
968
+ if node['type'] == 'sequence':
969
+ if node['data']['is_query']:
970
+ return colors['query_sequence'], 10 # Reduced size for compactness
971
+ elif node['data']['is_matched']:
972
+ return colors['matched_sequence'], 8
973
+ else:
974
+ return colors['other_sequence'], 6
975
+ elif node['type'] == 'representative_sequence':
976
+ return colors['representative_sequence'], 7
977
+ elif node['type'] == 'normalized_ml_group':
978
+ # Larger size if it has special sequences
979
+ size = 9 if node.get('has_special_sequences', False) else 7
980
+ return colors['normalized_ml_group'], size
981
+ elif node['type'] == 'full_ml_group':
982
+ return colors['full_ml_group'], 8
983
+ elif node['type'] == 'genotype':
984
+ return colors['genotype'], 7
985
+ else:
986
+ return colors.get(node['type'], '#000000'), 7
987
+
988
+ def create_node_text(node):
989
+ """Create appropriate text label for each node type"""
990
+                if node['type'] == 'normalized_ml_group':
+                    # Add indicator if it has special sequences
+                    suffix = " *" if node.get('has_special_sequences', False) else ""
+                    return f"{node['name']}{suffix}"
+                # All other node types are labeled with their name as-is
+                return node['name']
1003
+
1004
+ def create_hover_text(node):
1005
+ """Create detailed hover text for each node type"""
1006
+ if node['type'] == 'sequence':
1007
+ data = node['data']['data']
1008
+ hover_text = (
1009
+ f"<b>{node['name']}</b><br>"
1010
+ f"Type: {'Query Sequence' if node['data']['is_query'] else 'Matched Sequence' if node['data']['is_matched'] else 'Other Sequence'}<br>"
1011
+ f"ML Group: {data.get('ML', 'N/A')}<br>"
1012
+ f"Genotype: {data.get('Genotype', 'N/A')}<br>"
1013
+ f"Host: {data.get('Host', 'N/A')}<br>"
1014
+ f"Country: {data.get('Country', 'N/A')}<br>"
1015
+ f"Isolate: {data.get('Isolate', 'N/A')}<br>"
1016
+ f"Year: {data.get('Year', 'N/A')}"
1017
+ )
1018
+ if node['data']['is_matched']:
1019
+ hover_text += f"<br><b>Similarity: {node['data']['similarity']}%</b>"
1020
+ elif node['type'] == 'representative_sequence':
1021
+ data = node['data']['data']
1022
+ hover_text = (
1023
+ f"<b>{node['name']}</b><br>"
1024
+ f"Type: Representative Sequence<br>"
1025
+ f"ML Group: {data.get('ML', 'N/A')}<br>"
1026
+ f"Genotype: {data.get('Genotype', 'N/A')}<br>"
1027
+ f"Host: {data.get('Host', 'N/A')}<br>"
1028
+ f"Country: {data.get('Country', 'N/A')}"
1029
+ )
1030
+ elif node['type'] == 'normalized_ml_group':
1031
+ hover_text = f"<b>{node['name']}</b><br>Type: Normalized ML Group"
1032
+ if node.get('has_special_sequences', False):
1033
+ hover_text += "<br>Contains query/matched sequences"
1034
+ else:
1035
+ hover_text += "<br>Representative sequences only"
1036
+ elif node['type'] == 'full_ml_group':
1037
+ hover_text = f"<b>{node['name']}</b><br>Type: Full ML Group"
1038
+ if 'sequences_count' in node:
1039
+ hover_text += f"<br>Total Sequences: {node['sequences_count']}"
1040
+ elif node['type'] == 'genotype':
1041
+ hover_text = f"<b>{node['name']}</b><br>Type: Genotype"
1042
+ if 'sequences' in node:
1043
+ special_count = sum(1 for seq in node['sequences'] if seq['is_query'] or seq['is_matched'])
1044
+ hover_text += f"<br>Special Sequences: {special_count}/{len(node['sequences'])}"
1045
+ else:
1046
+ hover_text = f"<b>{node['name']}</b><br>Type: {node['type'].replace('_', ' ').title()}"
1047
+
1048
+ return hover_text
1049
+
1050
+ def add_node_and_edges(node, parent_x=None, parent_y=None):
1051
+ """Recursively add nodes and edges to the plot with equal spacing structure."""
1052
+ x, y = node['x'], node['y']
1053
+ node_x.append(x)
1054
+ node_y.append(y)
1055
+
1056
+ # Get node color and size
1057
+ color, size = get_node_color_and_size(node)
1058
+ node_colors.append(color)
1059
+ node_sizes.append(size)
1060
+
1061
+ # Create node text and hover
1062
+ node_text.append(create_node_text(node))
1063
+ node_hover.append(create_hover_text(node))
1064
+
1065
+ # Process children with equal spacing structure
1066
+ if 'children' in node and node['children']:
1067
+ add_horizontal_edges(x, y, node['children'])
1068
+ for child in node['children'].values():
1069
+ add_node_and_edges(child, x, y)
1070
+
1071
+ # Build the plot data starting from root
1072
+ root_node = self.tree_structure['root']
1073
+ add_node_and_edges(root_node)
1074
+
1075
+            # Root-level edges are already added by add_node_and_edges(root_node) above;
+            # adding them again here would draw every root edge twice
1078
+
1079
+ # Create the figure
1080
+ fig = go.Figure()
1081
+
1082
+ # Add edges
1083
+ fig.add_trace(go.Scatter(
1084
+ x=edge_x, y=edge_y,
1085
+ mode='lines',
1086
+ line=dict(width=1, color='gray', dash='solid'), # Thinner lines for compactness
1087
+ hoverinfo='none',
1088
+ showlegend=False,
1089
+ name='Edges'
1090
+ ))
1091
+
1092
+ # Add nodes
1093
+ fig.add_trace(go.Scatter(
1094
+ x=node_x, y=node_y,
1095
+ mode='markers+text',
1096
+ marker=dict(
1097
+ size=node_sizes,
1098
+ color=node_colors,
1099
+ line=dict(width=1, color='black'), # Thinner borders
1100
+ opacity=0.85
1101
+ ),
1102
+ text=node_text,
1103
+ textposition="middle right",
1104
+ textfont=dict(size=9, color="black"), # Smaller font for compactness
1105
+ hoverinfo='text',
1106
+ hovertext=node_hover,
1107
+ showlegend=False,
1108
+ name='Nodes'
1109
+ ))
1110
+
1111
+ # Calculate proper layout dimensions to ensure everything fits
1112
+ if node_x and node_y:
1113
+ # Get the actual data bounds
1114
+ min_x, max_x = min(node_x), max(node_x)
1115
+ min_y, max_y = min(node_y), max(node_y)
1116
+
1117
+ # Calculate ranges
1118
+ x_range = max_x - min_x
1119
+ y_range = max_y - min_y
1120
+
1121
+ # Add padding to ensure nothing is cut off (20% padding on each side)
1122
+ x_padding = x_range * 0.2 if x_range > 0 else 1
1123
+ y_padding = y_range * 0.2 if y_range > 0 else 1
1124
+
1125
+ # Set axis ranges with padding
1126
+ x_axis_range = [min_x - x_padding, max_x + x_padding]
1127
+ y_axis_range = [min_y - y_padding, max_y + y_padding]
1128
+
1129
+ # Compact but sufficient sizing
1130
+ width = min(1400, max(800, int(x_range * 80 + 400))) # Cap max width
1131
+ height = min(900, max(500, int(y_range * 40 + 300))) # Cap max height
1132
+ else:
1133
+ width, height = 800, 500
1134
+ x_axis_range = None
1135
+ y_axis_range = None
1136
+
1137
+ # Update layout for compact horizontal tree with proper bounds
1138
+ fig.update_layout(
1139
+ title=dict(
1140
+ text=f"Compact Horizontal Phylogenetic Tree (ML-Based)<br>"
1141
+ f"Query: {self.query_id} | Similarity: {actual_percentage}% | "
1142
+ f"Matched: {len(matched_ids)}",
1143
+ x=0.5,
1144
+ font=dict(size=12) # Smaller title for compactness
1145
+ ),
1146
+ xaxis=dict(
1147
+ showgrid=False,
1148
+ gridcolor='lightgray',
1149
+ gridwidth=0.3, # Very thin grid lines
1150
+ zeroline=False,
1151
+ showticklabels=False,
1152
+ range=x_axis_range, # Set explicit range to prevent cutoff
1153
+ fixedrange=False, # Allow zooming if needed
1154
+ automargin=True # Automatically adjust margins
1155
+ ),
1156
+ yaxis=dict(
1157
+ showgrid=False,
1158
+ gridcolor='lightgray',
1159
+ gridwidth=0.3, # Very thin grid lines
1160
+ zeroline=False,
1161
+ showticklabels=False,
1162
+ range=y_axis_range, # Set explicit range to prevent cutoff
1163
+ fixedrange=False, # Allow zooming if needed
1164
+ automargin=True # Automatically adjust margins
1165
+ ),
1166
+ plot_bgcolor="white",
1167
+ paper_bgcolor="white",
1168
+ hovermode="closest",
1169
+ width=width,
1170
+ height=height,
1171
+ margin=dict(l=20, r=100, t=40, b=10), # Adequate margins, extra right margin for text
1172
+ autosize=False, # Don't auto-resize
1173
+ showlegend=True,
1174
+ legend=dict(
1175
+ x=1.02, # Position legend outside plot area
1176
+ y=1,
1177
+ xanchor='left',
1178
+ yanchor='top',
1179
+ bgcolor='rgba(255,255,255,0.8)',
1180
+ bordercolor='gray',
1181
+ borderwidth=1,
1182
+ font=dict(size=10) # Smaller legend font
1183
+ )
1184
+ )
1185
+
1186
+ # Add comprehensive legend with smaller markers
1187
+ legend_elements = [
1188
+ dict(name="Root", marker=dict(color=colors['root'], size=8)),
1189
+ dict(name="Normalized ML Groups", marker=dict(color=colors['normalized_ml_group'], size=8)),
1190
+ dict(name="Full ML Groups", marker=dict(color=colors['full_ml_group'], size=8)),
1191
+ dict(name="Genotypes", marker=dict(color=colors['genotype'], size=8)),
1192
+ dict(name="Query Sequence", marker=dict(color=colors['query_sequence'], size=10)),
1193
+ dict(name="Similar Sequences", marker=dict(color=colors['matched_sequence'], size=9)),
1194
+ dict(name="Representative Sequences", marker=dict(color=colors['representative_sequence'], size=8)),
1195
+ dict(name="Other Sequences", marker=dict(color=colors['other_sequence'], size=7))
1196
+ ]
1197
+
1198
+            for element in legend_elements:
1199
+ fig.add_trace(go.Scatter(
1200
+ x=[None], y=[None],
1201
+ mode='markers',
1202
+ marker=element['marker'],
1203
+ name=element['name'],
1204
+ showlegend=True
1205
+ ))
1206
+
1207
+
1208
+ # Configure modebar for better user experience
1209
+ config = {
1210
+ 'displayModeBar': True,
1211
+ 'displaylogo': False,
1212
+ 'modeBarButtonsToRemove': ['select2d', 'lasso2d'],
1213
+ 'toImageButtonOptions': {
1214
+ 'format': 'png',
1215
+ 'filename': 'phylogenetic_tree',
1216
+ 'height': height,
1217
+ 'width': width,
1218
+ 'scale': 2
1219
+ }
1220
+ }
1221
+
1222
+ # Save outputs
1223
+ try:
1224
+ fig.write_html("phylogenetic_tree_normalized_horizontal.html", config=config)
1225
+ print("✓ Compact horizontal interactive tree saved as 'phylogenetic_tree_normalized_horizontal.html'")
1226
+ except Exception as e:
1227
+ print(f"Warning: Could not save HTML file: {e}")
1228
+
1229
+ # Display the figure with config
1230
+ try:
1231
+ fig.show(config=config)
1232
+ except Exception as e:
1233
+ print(f"Warning: Could not display figure: {e}")
1234
+
1235
+ return fig
1236
+
1237
+ except Exception as e:
1238
+ print(f"Error creating compact horizontal interactive tree: {e}")
1239
+ return None
1240
+
1241
+
1242
+ def create_sequence_alignment(self, sequence_ids: List[str]) -> Optional[MultipleSeqAlignment]:
1243
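+        """Build a gap-padded pseudo-alignment of the F-gene sequences for the given accession numbers; returns None if fewer than 2 are usable"""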
1244
+ try:
1245
+ print("🧬 Creating multiple sequence alignment...")
1246
+
1247
+ # Get sequences
1248
+ sequences = []
1249
+ for seq_id in sequence_ids:
1250
+ try:
1251
+ row = self.data[self.data['Accession Number'] == seq_id]
1252
+ if not row.empty:
1253
+ f_gene = str(row.iloc[0]['F-gene'])
1254
+ # Clean sequence (remove non-nucleotide characters)
1255
+ clean_seq = re.sub(r'[^ATGCN-]', '', f_gene.upper())
1256
+ if len(clean_seq) > 10: # Minimum sequence length
1257
+ seq_record = SeqRecord(Seq(clean_seq), id=seq_id, description="")
1258
+ sequences.append(seq_record)
1259
+ except Exception as e:
1260
+ print(f"Warning: Skipping sequence {seq_id}: {e}")
1261
+ continue
1262
+
1263
+ if len(sequences) < 2:
1264
+ print("❌ Need at least 2 valid sequences for alignment")
1265
+ return None
1266
+
1267
+            # Naive alignment: pad with trailing gaps to equal length, no residue matching
+            # (use MUSCLE or Clustal for a real alignment)
1268
+ aligned_sequences = self._simple_alignment(sequences)
1269
+
1270
+ print(f"✓ Alignment created with {len(aligned_sequences)} sequences")
1271
+ return MultipleSeqAlignment(aligned_sequences)
1272
+
1273
+ except Exception as e:
1274
+ print(f"Error creating alignment: {e}")
1275
+ return None
1276
+
1277
+ def _simple_alignment(self, sequences: List[SeqRecord]) -> List[SeqRecord]:
1278
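+        """Pad sequences with trailing gaps to a common length (capped at 10,000 bp); no residues are actually aligned"""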
1279
+ try:
1280
+ # Find maximum length
1281
+ max_length = max(len(seq.seq) for seq in sequences)
1282
+
1283
+ # Cap maximum length to prevent memory issues
1284
+ if max_length > 10000:
1285
+ max_length = 10000
1286
+ print(f"Warning: Sequences truncated to {max_length} bp")
1287
+
1288
+ # Pad sequences to same length
1289
+ aligned_sequences = []
1290
+ for seq in sequences:
1291
+ seq_str = str(seq.seq)[:max_length] # Truncate if too long
1292
+
1293
+ if len(seq_str) < max_length:
1294
+ # Pad with gaps at the end
1295
+ padded_seq = seq_str + '-' * (max_length - len(seq_str))
1296
+ else:
1297
+ padded_seq = seq_str
1298
+
1299
+ aligned_sequences.append(SeqRecord(Seq(padded_seq), id=seq.id, description=seq.description))
1300
+
1301
+ return aligned_sequences
1302
+ except Exception as e:
1303
+ print(f"Error in simple alignment: {e}")
1304
+ return sequences # Return original sequences as fallback
1305
+
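# Illustrative sketch, not part of the committed file: what end-padding as in
# _simple_alignment does, and why it is not a real alignment. The helper names
# pad_to_equal_length and mismatches are hypothetical and work on plain
# strings rather than SeqRecord objects.

```python
from typing import List


def pad_to_equal_length(seqs: List[str], max_length: int = 10000, gap: str = "-") -> List[str]:
    """Truncate to max_length, then right-pad every sequence with gaps."""
    target = min(max(len(s) for s in seqs), max_length)
    return [s[:target].ljust(target, gap) for s in seqs]


def mismatches(a: str, b: str) -> int:
    """Count differing columns where both sequences have a base (not a gap)."""
    return sum(1 for x, y in zip(a, b) if x != y and "-" not in (x, y))


# "TGC" is "ATGC" with the first base missing
padded = pad_to_equal_length(["ATGC", "TGC"])
print(padded)
print(mismatches(*padded))
```

# A single missing base at the start shifts every later column, so the two
# sequences look almost entirely different even though they share "TGC". This
# is why a dedicated aligner is preferable whenever sequences differ in length.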
1306
+ def calculate_ml_distances(self, alignment: MultipleSeqAlignment) -> np.ndarray:
1307
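+        """Return a symmetric matrix of pairwise Jukes-Cantor distances (despite the name, no likelihood is maximized here)"""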
1308
+ try:
1309
+ print("📊 Calculating ML distances...")
1310
+
1311
+ # Convert alignment to numeric matrix
1312
+ seq_matrix = self._alignment_to_matrix(alignment)
1313
+ n_sequences = len(alignment)
1314
+
1315
+ if n_sequences == 0:
1316
+ return np.array([])
1317
+
1318
+ # Initialize distance matrix
1319
+ distance_matrix = np.zeros((n_sequences, n_sequences))
1320
+
1321
+ # Calculate pairwise ML distances
1322
+ for i in range(n_sequences):
1323
+ for j in range(i + 1, n_sequences):
1324
+ try:
1325
+ ml_distance = self._calculate_ml_distance_pair(seq_matrix[i], seq_matrix[j])
1326
+ distance_matrix[i][j] = ml_distance
1327
+ distance_matrix[j][i] = ml_distance
1328
+ except Exception as e:
1329
+ print(f"Warning: Error calculating distance between sequences {i} and {j}: {e}")
1330
+ # Use maximum distance as fallback
1331
+ distance_matrix[i][j] = 1.0
1332
+ distance_matrix[j][i] = 1.0
1333
+
1334
+ print("✓ ML distances calculated")
1335
+ return distance_matrix
1336
+
1337
+ except Exception as e:
1338
+ print(f"Error calculating ML distances: {e}")
1339
+ return np.array([])
1340
+
1341
+ def _alignment_to_matrix(self, alignment: MultipleSeqAlignment) -> np.ndarray:
1342
+        """Encode the alignment as integers: A=0, T=1, G=2, C=3, N or any other character=4, gap=5"""
1343
+ try:
1344
+ # Nucleotide to number mapping
1345
+ nucleotide_map = {'A': 0, 'T': 1, 'G': 2, 'C': 3, 'N': 4, '-': 5}
1346
+
1347
+ matrix = []
1348
+ for record in alignment:
1349
+ sequence = str(record.seq).upper()
1350
+ numeric_seq = [nucleotide_map.get(nuc, 4) for nuc in sequence]
1351
+ matrix.append(numeric_seq)
1352
+
1353
+ return np.array(matrix)
1354
+ except Exception as e:
1355
+ print(f"Error converting alignment to matrix: {e}")
1356
+ return np.array([])
1357
+
1358
+ def _calculate_ml_distance_pair(self, seq1: np.ndarray, seq2: np.ndarray) -> float:
1359
+        """Jukes-Cantor corrected distance between two numerically encoded sequences, capped at 1.0"""
1360
+ try:
1361
+ if len(seq1) == 0 or len(seq2) == 0:
1362
+ return 1.0
1363
+
1364
+            # Compare only positions where both sequences have an unambiguous base
+            # (codes 0-3); N (4) and gap (5) are excluded
+            valid_positions = (seq1 < 4) & (seq2 < 4)
+            total_valid = np.sum(valid_positions)
+
+            if total_valid == 0:
+                return 1.0  # Maximum distance if no valid comparisons
+
+            # Proportion of differing sites (p-distance)
+            differences = np.sum(seq1[valid_positions] != seq2[valid_positions])
+            p = differences / total_valid
+
+            # The Jukes-Cantor correction is undefined for p >= 0.75 (saturation)
+            if p >= 0.75:
+                return 1.0  # Maximum distance
+
+            # JC distance formula: d = -3/4 * ln(1 - 4p/3)
+            jc_distance = -0.75 * np.log(1 - (4 * p / 3))
+            return min(max(jc_distance, 0.0), 1.0)  # Clamp between 0 and 1
1389
+
1390
+ except Exception as e:
1391
+ return 1.0 # Return maximum distance on error
1392
+
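# Illustrative sketch, not part of the committed file: the Jukes-Cantor
# correction d = -3/4 * ln(1 - 4p/3) used by _calculate_ml_distance_pair,
# written with the standard library only. The name jukes_cantor is
# hypothetical; unlike the method above, this version does not clamp the
# result to 1.0 and returns infinity at saturation.

```python
import math


def jukes_cantor(p: float) -> float:
    """Convert a proportion of differing sites into a Jukes-Cantor distance."""
    if p >= 0.75:
        return math.inf  # correction is undefined at or beyond saturation
    return -0.75 * math.log(1.0 - 4.0 * p / 3.0)


for p in (0.0, 0.1, 0.3, 0.6):
    print(f"p = {p:.2f} -> d = {jukes_cantor(p):.4f}")
```

# The corrected distance exceeds 1.0 once p passes roughly 0.55, so the clamp
# in the method above flattens all more divergent pairs to the same value.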
1393
+ def construct_ml_tree(self, alignment: MultipleSeqAlignment) -> Optional[BaseTree.Tree]:
1394
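+        """Build a neighbor-joining tree from Jukes-Cantor distances, then rescale its branch lengths heuristically"""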
1395
+ try:
1396
+ print("🌳 Constructing Maximum Likelihood tree...")
1397
+
1398
+ # Calculate ML distance matrix
1399
+ distance_matrix = self.calculate_ml_distances(alignment)
1400
+
1401
+ if distance_matrix.size == 0:
1402
+ return None
1403
+
1404
+ # Create sequence names list
1405
+ sequence_names = [record.id for record in alignment]
1406
+
1407
+            # Build tree using neighbor-joining on the Jukes-Cantor distances
1408
+ tree = self._build_nj_tree_from_distances(distance_matrix, sequence_names)
1409
+
1410
+            # Rescale branch lengths heuristically (with recursion protection)
1411
+ if tree:
1412
+ tree = self._optimize_branch_lengths_ml_safe(tree, alignment)
1413
+
1414
+ print("✓ ML tree constructed successfully")
1415
+ return tree
1416
+
1417
+ except Exception as e:
1418
+ print(f"Error constructing ML tree: {e}")
1419
+ return None
1420
+
1421
+ def _build_nj_tree_from_distances(self, distance_matrix: np.ndarray, sequence_names: List[str]) -> Optional[BaseTree.Tree]:
1422
+        """Run Biopython neighbor-joining on a square distance matrix; returns None on error"""
1423
+ try:
1424
+ from Bio.Phylo.TreeConstruction import DistanceMatrix, DistanceTreeConstructor
1425
+
1426
+ # Validate inputs
1427
+ if distance_matrix.shape[0] != len(sequence_names):
1428
+ print("Error: Distance matrix size doesn't match sequence names")
1429
+ return None
1430
+
1431
+ # Convert numpy array to Bio.Phylo distance matrix format
1432
+ matrix_data = []
1433
+ for i in range(len(sequence_names)):
1434
+ row = []
1435
+ for j in range(i + 1):
1436
+ if i == j:
1437
+ row.append(0.0)
1438
+ else:
1439
+ # Ensure distance is valid
1440
+ dist = float(distance_matrix[i][j])
1441
+ if np.isnan(dist) or np.isinf(dist):
1442
+ dist = 1.0
1443
+ row.append(max(0.0, dist)) # Ensure non-negative
1444
+ matrix_data.append(row)
1445
+
1446
+ # Create DistanceMatrix object
1447
+ dm = DistanceMatrix(names=sequence_names, matrix=matrix_data)
1448
+
1449
+ # Build tree using Neighbor-Joining
1450
+ constructor = DistanceTreeConstructor()
1451
+ tree = constructor.nj(dm)
1452
+
1453
+ # Validate tree structure
1454
+ if tree and self._validate_tree_structure(tree):
1455
+ return tree
1456
+ else:
1457
+ print("Warning: Tree structure validation failed")
1458
+ return tree # Return anyway, might still be usable
1459
+
1460
+ except Exception as e:
1461
+ print(f"Error building NJ tree: {e}")
1462
+ return None
1463
+
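# Illustrative sketch, not part of the committed file: the conversion that
# _build_nj_tree_from_distances performs before calling Biopython, whose
# DistanceMatrix takes a lower-triangular list of rows including the diagonal.
# The name to_lower_triangular is hypothetical and the sketch uses plain lists
# instead of a NumPy array.

```python
import math
from typing import List, Sequence


def to_lower_triangular(square: Sequence[Sequence[float]]) -> List[List[float]]:
    """Keep row i up to column i; zero the diagonal and sanitize bad values."""
    rows = []
    for i in range(len(square)):
        row = []
        for j in range(i + 1):
            d = 0.0 if i == j else float(square[i][j])
            if math.isnan(d) or math.isinf(d):
                d = 1.0  # same fallback as the method above
            row.append(max(0.0, d))
        rows.append(row)
    return rows


square = [
    [0.0, 0.1, 0.2],
    [0.1, 0.0, float("nan")],
    [0.2, float("nan"), 0.0],
]
print(to_lower_triangular(square))
```

# Row i ends up with i + 1 entries, so a 3 x 3 matrix becomes rows of length
# 1, 2 and 3, with the invalid entry replaced by the maximum distance 1.0.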
1464
+ def _validate_tree_structure(self, tree: BaseTree.Tree, max_depth: int = 100) -> bool:
1465
+        """Return False if the tree is deeper than max_depth or a clade is reachable more than once"""
1466
+ try:
1467
+ visited = set()
1468
+
1469
+ def check_node(node, depth=0):
1470
+ if depth > max_depth:
1471
+ return False
1472
+
1473
+ # Check for circular references
1474
+ node_id = id(node)
1475
+ if node_id in visited:
1476
+ return False
1477
+ visited.add(node_id)
1478
+
1479
+ # Check children
1480
+ for child in getattr(node, 'clades', []):
1481
+ if not check_node(child, depth + 1):
1482
+ return False
1483
+
1484
+ return True
1485
+
1486
+ return check_node(tree.root if hasattr(tree, 'root') else tree)
1487
+ except Exception:
1488
+ return False
1489
+
1490
+ def _optimize_branch_lengths_ml_safe(self, tree: BaseTree.Tree, alignment: MultipleSeqAlignment) -> BaseTree.Tree:
1491
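+        """Rescale every branch length in place via _calculate_optimal_branch_length_safe; returns the tree unchanged on failure"""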
1492
+ try:
1493
+ print("🔧 Optimizing branch lengths with ML...")
1494
+
1495
+            # Temporarily raise the recursion limit to at least 1000 (never lower it)
+            old_limit = sys.getrecursionlimit()
+            sys.setrecursionlimit(max(old_limit, 1000))
1498
+
1499
+ try:
1500
+ # Convert alignment to matrix
1501
+ seq_matrix = self._alignment_to_matrix(alignment)
1502
+
1503
+ if seq_matrix.size == 0:
1504
+ print("Warning: Empty sequence matrix, skipping optimization")
1505
+ return tree
1506
+
1507
+ # Get all internal and external nodes with depth tracking
1508
+ all_clades = self._get_clades_safe(tree)
1509
+
1510
+ # Simple branch length optimization
1511
+ for clade in all_clades:
1512
+ if hasattr(clade, 'branch_length') and clade.branch_length is not None:
1513
+ try:
1514
+                            # Rescale the branch length heuristically (no likelihood is evaluated)
1515
+ optimal_length = self._calculate_optimal_branch_length_safe(clade, seq_matrix)
1516
+ clade.branch_length = max(optimal_length, 0.001) # Minimum branch length
1517
+ except Exception as e:
1518
+ print(f"Warning: Failed to optimize branch for clade: {e}")
1519
+ # Keep original branch length
1520
+ pass
1521
+
1522
+ print("✓ Branch lengths optimized")
1523
+
1524
+ finally:
1525
+ # Restore original recursion limit
1526
+ sys.setrecursionlimit(old_limit)
1527
+
1528
+ return tree
1529
+
1530
+ except Exception as e:
1531
+ print(f"Warning: Branch length optimization failed: {e}")
1532
+ return tree
1533
+
1534
+ def _get_clades_safe(self, tree: BaseTree.Tree, max_depth: int = 50) -> List:
1535
+        """Collect all clades depth-first, stopping at max_depth and skipping clades already visited"""
1536
+ clades = []
1537
+ visited = set()
1538
+
1539
+ def traverse_node(node, depth=0):
1540
+ if depth > max_depth or id(node) in visited:
1541
+ return
1542
+
1543
+ visited.add(id(node))
1544
+ clades.append(node)
1545
+
1546
+ # Traverse children safely
1547
+ try:
1548
+ children = getattr(node, 'clades', [])
1549
+ for child in children:
1550
+ traverse_node(child, depth + 1)
1551
+ except Exception:
1552
+ pass # Skip problematic nodes
1553
+
1554
+ try:
1555
+ root = tree.root if hasattr(tree, 'root') else tree
1556
+ traverse_node(root)
1557
+ except Exception as e:
1558
+ print(f"Warning: Tree traversal error: {e}")
1559
+
1560
+ return clades
1561
+
1562
+ def _calculate_optimal_branch_length_safe(self, clade, seq_matrix: np.ndarray) -> float:
1563
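+        """Return a heuristically rescaled branch length in [0.001, 1.0]; 0.1 if the current length is missing or invalid"""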
1564
+ try:
1565
+ # Simplified ML branch length estimation
1566
+ if not hasattr(clade, 'branch_length') or clade.branch_length is None:
1567
+ return 0.1 # Default branch length
1568
+
1569
+ current_length = float(clade.branch_length)
1570
+
1571
+ # Validate current length
1572
+ if np.isnan(current_length) or np.isinf(current_length) or current_length <= 0:
1573
+ return 0.1
1574
+
1575
+            # Heuristic rescaling, not a likelihood optimization: shrink terminal
+            # branches by 10% and stretch internal ones by 10%, clamped to [0.001, 1.0].
+            # Use is_terminal() rather than the clade name, because Biopython's NJ
+            # constructor also names internal nodes ("Inner1", "Inner2", ...).
+            if clade.is_terminal():
+                return min(max(current_length * 0.9, 0.001), 1.0)
+            else:
+                return min(max(current_length * 1.1, 0.001), 1.0)
1582
+
1583
+ except Exception:
1584
+ return 0.1 # Safe default
1585
+
1586
+ def calculate_ml_likelihood_safe(self, tree: BaseTree.Tree, alignment: MultipleSeqAlignment) -> float:
1587
+        """Return a heuristic log-score summed over a sample of alignment sites; -inf on error"""
1588
+ try:
1589
+ print("📈 Calculating tree likelihood...")
1590
+
1591
+ seq_matrix = self._alignment_to_matrix(alignment)
1592
+
1593
+ if seq_matrix.size == 0:
1594
+ return -np.inf
1595
+
1596
+            # Simplified pseudo-likelihood: sum of per-site log scores
+            # (diversity heuristic; no substitution model is evaluated on the tree)
1597
+ total_log_likelihood = 0.0
1598
+
1599
+ # For each site in the alignment (sample subset to avoid memory issues)
1600
+ n_sites = min(seq_matrix.shape[1], 1000) # Limit sites for performance
1601
+
1602
+ for site in range(0, n_sites, max(1, n_sites // 100)): # Sample sites
1603
+ try:
1604
+ site_pattern = seq_matrix[:, site]
1605
+
1606
+                    # Skip sites with fewer than two unambiguous bases (A/T/G/C)
1607
+ valid_positions = site_pattern < 4
1608
+ if np.sum(valid_positions) < 2:
1609
+ continue
1610
+
1611
+ # Calculate likelihood for this site pattern
1612
+ site_likelihood = self._calculate_site_likelihood_safe(tree, site_pattern)
1613
+
1614
+ if site_likelihood > 0:
1615
+ total_log_likelihood += np.log(site_likelihood)
1616
+
1617
+ except Exception as e:
1618
+ print(f"Warning: Error processing site {site}: {e}")
1619
+ continue
1620
+
1621
+ print(f"✓ Tree likelihood calculated: {total_log_likelihood:.2f}")
1622
+ return total_log_likelihood
1623
+
1624
+ except Exception as e:
1625
+ print(f"Error calculating likelihood: {e}")
1626
+ return -np.inf
1627
+
1628
+ def _calculate_site_likelihood_safe(self, tree: BaseTree.Tree, site_pattern: np.ndarray) -> float:
1629
+
1630
+ try:
1631
+ # Count nucleotide frequencies at this site
1632
+ valid_nucs = site_pattern[site_pattern < 4]
1633
+
1634
+ if len(valid_nucs) == 0:
1635
+ return 1.0
1636
+
1637
+ # Simple likelihood based on nucleotide diversity
1638
+ unique_nucs = len(np.unique(valid_nucs))
1639
+ total_nucs = len(valid_nucs)
1640
+
1641
+ # Higher diversity = lower likelihood of simple evolution
1642
+ diversity_factor = unique_nucs / 4.0 # Normalize by 4 nucleotides
1643
+
1644
+ # Simple likelihood model
1645
+ likelihood = np.exp(-diversity_factor * total_nucs * 0.1)
1646
+
1647
+ return max(likelihood, 1e-10) # Avoid zero likelihood
1648
+
1649
+ except Exception:
1650
+ return 1e-10 # Safe fallback
1651
+
1652
+ def perform_ml_analysis_safe(self, matched_ids: List[str]) -> Dict:
1653
+
1654
+ try:
1655
+ print("\n🧬 PERFORMING MAXIMUM LIKELIHOOD ANALYSIS")
1656
+ print("="*50)
1657
+
1658
+ # Include query sequence in analysis
1659
+ all_sequences = [self.query_id] + [seq_id for seq_id in matched_ids if seq_id != self.query_id]
1660
+
1661
+ # Limit number of sequences to prevent memory issues
1662
+ if len(all_sequences) > 20:
1663
+ print(f"Warning: Limiting analysis to 20 sequences (had {len(all_sequences)})")
1664
+ all_sequences = all_sequences[:20]
1665
+
1666
+ if len(all_sequences) < 3:
1667
+ print("❌ Need at least 3 sequences for ML analysis")
1668
+ return {}
1669
+
1670
+ # Step 1: Create multiple sequence alignment
1671
+ alignment = self.create_sequence_alignment(all_sequences)
1672
+ if not alignment:
1673
+ return {}
1674
+
1675
+ # Step 2: Calculate ML distances
1676
+ distance_matrix = self.calculate_ml_distances(alignment)
1677
+ if distance_matrix.size == 0:
1678
+ return {}
1679
+
1680
+ # Step 3: Construct ML tree
1681
+ ml_tree = self.construct_ml_tree(alignment)
1682
+ if not ml_tree:
1683
+ return {}
1684
+
1685
+ # Step 4: Calculate tree likelihood (safely)
1686
+ log_likelihood = self.calculate_ml_likelihood_safe(ml_tree, alignment)
1687
+
1688
+ # Step 5: Prepare results
1689
+ ml_results = {
1690
+ 'tree': ml_tree,
1691
+ 'alignment': alignment,
1692
+ 'distance_matrix': distance_matrix,
1693
+ 'log_likelihood': log_likelihood,
1694
+ 'sequence_count': len(all_sequences),
1695
+ 'alignment_length': len(alignment[0]) if alignment else 0
1696
+ }
1697
+
1698
+             print("✅ ML analysis completed successfully")
1699
+ print(f" Sequences analyzed: {len(all_sequences)}")
1700
+ print(f" Alignment length: {ml_results['alignment_length']}")
1701
+ print(f" Log-likelihood: {log_likelihood:.2f}")
1702
+
1703
+ return ml_results
1704
+
1705
+ except Exception as e:
1706
+ print(f"❌ ML analysis failed: {e}")
1707
+ import traceback
1708
+ traceback.print_exc()
1709
+ return {}
1710
+
1711
+ def build_tree_structure_with_ml_safe(self, matched_ids: List[str]) -> Dict:
1712
+
1713
+ try:
1714
+ print("🌳 Building ML-enhanced tree structure...")
1715
+
1716
+ # Perform ML analysis first
1717
+ ml_results = self.perform_ml_analysis_safe(matched_ids)
1718
+
1719
+ # Build the original hierarchical structure
1720
+ tree_structure = self.build_tree_structure(matched_ids)
1721
+
1722
+ # Enhance with ML information
1723
+ if ml_results and 'tree' in ml_results:
1724
+ tree_structure['ml_analysis'] = {
1725
+ 'log_likelihood': ml_results['log_likelihood'],
1726
+ 'sequence_count': ml_results['sequence_count'],
1727
+ 'alignment_length': ml_results['alignment_length'],
1728
+ 'ml_tree_available': True
1729
+ }
1730
+
1731
+ # Store ML tree for later use
1732
+ self.ml_tree = ml_results['tree']
1733
+ self.ml_alignment = ml_results.get('alignment')
1734
+
1735
+ print("✓ Tree structure enhanced with ML analysis")
1736
+ else:
1737
+ tree_structure['ml_analysis'] = {
1738
+ 'ml_tree_available': False,
1739
+ 'error': 'ML analysis failed'
1740
+ }
1741
+ print("⚠️ ML analysis failed, using standard tree structure")
1742
+
1743
+ return tree_structure
1744
+
1745
+ except Exception as e:
1746
+ print(f"Error building ML-enhanced tree structure: {e}")
1747
+ # Fallback to original method
1748
+ try:
1749
+ return self.build_tree_structure(matched_ids)
1750
+ except Exception as e2:
1751
+ print(f"Fallback also failed: {e2}")
1752
+ return {'error': 'Both ML and standard tree construction failed'}
1753
+
1754
+
1755
+ def _print_tree_topology(self, tree, max_depth=3, current_depth=0, prefix=""):
1756
+
1757
+ if current_depth > max_depth:
1758
+ return
1759
+
1760
+ try:
1761
+             # find_clades() walks the whole tree in preorder, so this lists every clade, not a single level
1762
+ clades = list(tree.find_clades())
1763
+
1764
+ for i, clade in enumerate(clades[:5]): # Limit to first 5 for readability
1765
+ branch_info = ""
1766
+ if clade.branch_length is not None:
1767
+ branch_info = f" (len: {clade.branch_length:.4f})"
1768
+
1769
+ if clade.is_terminal():
1770
+ node_name = clade.name or "Terminal"
1771
+ print(f" {prefix}├── {node_name}{branch_info}")
1772
+ else:
1773
+ node_name = clade.name or f"Internal_{i}"
1774
+ print(f" {prefix}├── {node_name}{branch_info}")
1775
+
1776
+ if current_depth < max_depth - 1 and not clade.is_terminal():
1777
+                         # Show up to two direct children; find_clades() would also yield
+                         # the clade itself and every deeper descendant
+                         children = list(getattr(clade, 'clades', []))
+                         if children:
+                             for j, child in enumerate(children[:2]):  # Show max 2 children
+                                 child_name = child.name or f"Node_{j}"
+                                 child_branch = f" (len: {child.branch_length:.4f})" if child.branch_length is not None else ""
1783
+ print(f" {prefix}│ ├── {child_name}{child_branch}")
1784
+
1785
+ except Exception as e:
1786
+ print(f" Error displaying topology: {e}")
1787
+
1788
+
1789
+
1790
+ def main():
1791
+ print("\n" + "="*70)
1792
+ print("🧬 PHYLOGENETIC TREE ANALYZER - ADVANCED ML-BASED ANALYSIS")
1793
+ print("="*70)
1794
+ print("Version 2.0 | AI-Enhanced Similarity Matching")
1795
+ print("Interactive Visualization with Variable Line Lengths")
1796
+ print("="*70)
1797
+
1798
+ # Initialize the analyzer
1799
+ analyzer = PhylogeneticTreeAnalyzer()
1800
+
1801
+ try:
1802
+ # Step 1: Load data
1803
+ while True:
1804
+             # Prompt for the path; a hard-coded path would make the retry loop spin forever on failure
+             data_file = input("Enter path to CSV data file [f cleaned.csv]: ").strip() or "f cleaned.csv"
+
+             if not Path(data_file).exists():
+                 print(f"❌ File not found: {data_file}")
+                 continue
+
+             if analyzer.load_data(data_file):
+                 break
+             print("❌ Failed to load data. Please check file format.")
1818
+
1819
+ # Step 2: Train AI model automatically
1820
+ print("\n⏳ Training AI model... This may take a few moments.", flush=True)
1821
+ start_time = time.time()
1822
+ if analyzer.train_ai_model():
1823
+ elapsed = time.time() - start_time
1824
+ print(f"✅ AI model training completed in {elapsed:.1f} seconds", flush=True)
1825
+ else:
1826
+ print("⚠️ AI model training failed, continuing with basic analysis", flush=True)
1827
+
1828
+ # Step 3: Get query sequence
1829
+ while True:
1830
+ print("\n🔍 QUERY SEQUENCE INPUT:")
1831
+ print(" You can provide:")
1832
+ print(" 1. Accession Number (e.g., 'MH087032') - from your dataset")
1833
+ print(" 2. ANY F-gene nucleotide sequence (A, T, G, C)")
1834
+ print(" 3. Novel sequences will be compared against your dataset")
1835
+ print(" Note: Minimum sequence length is 10 nucleotides")
1836
+
1837
+ query_input = input("\nEnter query sequence or ID: ").strip()
1838
+ if not query_input:
1839
+ print("❌ Please provide a query sequence or ID.")
1840
+ continue
1841
+
1842
+ if analyzer.find_query_sequence(query_input):
1843
+ break
1844
+ else:
1845
+ retry = input("❌ Invalid input. Try again? (y/n): ").strip().lower()
1846
+ if retry != 'y':
1847
+ print("👋 Analysis cancelled.")
1848
+ return
1849
+
1850
+ # Step 4: Set similarity percentage
1851
+ while True:
1852
+ try:
1853
+ print(f"\n📊 SIMILARITY THRESHOLD:")
1854
+ print(f" - Higher values (90-99%): Find very similar sequences")
1855
+ print(f" - Lower values (70-89%): Find more distantly related sequences")
1856
+
1857
+ similarity_input = input(f"Enter target similarity percentage (1-99) [85]: ").strip()
1858
+ if not similarity_input:
1859
+ target_percentage = 85.0 # Lowered default for novel sequences
1860
+ else:
1861
+ target_percentage = float(similarity_input)
1862
+
1863
+ if not (1 <= target_percentage <= 99):
1864
+ print("❌ Please enter a percentage between 1 and 99.")
1865
+ continue
1866
+
1867
+ analyzer.matching_percentage = target_percentage
1868
+ break
1869
+
1870
+ except ValueError:
1871
+ print("❌ Please enter a valid number.")
1872
+ continue
1873
+
1874
+ # Step 5: Find similar sequences
1875
+ print(f"\n⏳ Analyzing sequences for {target_percentage}% similarity...")
1876
+ start_time = time.time()
1877
+
1878
+ matched_ids, actual_percentage = analyzer.find_similar_sequences(target_percentage)
1879
+
1880
+ if not matched_ids:
1881
+ print(f"❌ No similar sequences found at {target_percentage}% similarity.")
1882
+ print("💡 Try lowering the similarity percentage (e.g., 70-80%) to find more distant matches.")
1883
+ return
1884
+
1885
+ analyzer.matched_sequences = matched_ids
1886
+ analyzer.actual_percentage = actual_percentage
1887
+
1888
+ elapsed = time.time() - start_time
1889
+ print(f"✅ Similarity analysis completed in {elapsed:.1f} seconds")
1890
+
1891
+ # Step 6: Build tree structure
1892
+ print("\n⏳ Building phylogenetic tree structure...")
1893
+ start_time = time.time()
1894
+
1895
+ tree_structure = analyzer.build_tree_structure_with_ml_safe(matched_ids)
1896
+         if not tree_structure or 'error' in tree_structure:
1897
+ print("❌ Failed to build tree structure.")
1898
+ return
1899
+
1900
+ elapsed = time.time() - start_time
1901
+ print(f"✅ Tree structure built in {elapsed:.1f} seconds")
1902
+
1903
+ # Step 7: Create visualization and save HTML
1904
+ print("\n⏳ Creating interactive visualization...")
1905
+ start_time = time.time()
1906
+
1907
+ fig = analyzer.create_interactive_tree(matched_ids, actual_percentage)
1908
+ if fig:
1909
+ elapsed = time.time() - start_time
1910
+ print(f"✅ Visualization created in {elapsed:.1f} seconds")
1911
+
1912
+ # Save the interactive HTML file
1913
+ html_filename = "phylogenetic_tree_interactive.html"
1914
+ fig.write_html(html_filename)
1915
+ print(f"📄 Interactive HTML saved: {html_filename}")
1916
+
1917
+ print(f"\n🎉 Analysis completed successfully!")
1918
+ print(f" Query ID: {analyzer.query_id}")
1919
+ print(f" Query sequence length: {len(analyzer.query_sequence)} nucleotides")
1920
+ print(f" Similar sequences found: {len(matched_ids)}")
1921
+ print(f" Actual similarity percentage: {actual_percentage:.1f}%")
1922
+ print(f" HTML file generated: {html_filename}")
1923
+ else:
1924
+ print("❌ Visualization creation failed.")
1925
+ return
1926
+
1927
+ except KeyboardInterrupt:
1928
+ print(f"\n\n⚠️ Analysis interrupted by user.")
1929
+ sys.exit(1)
1930
+ except Exception as e:
1931
+ print(f"\n❌ An error occurred during analysis: {e}")
1932
+ print(f"Please check your input data and try again.")
1933
+ sys.exit(1)
1934
+
1935
+
1936
+ def command_line_interface():
1937
+ parser = argparse.ArgumentParser(
1938
+ description="Advanced Phylogenetic Tree Analyzer with AI-enhanced similarity matching",
1939
+ formatter_class=argparse.RawDescriptionHelpFormatter,
1940
+ epilog="""
1941
+ Examples:
1942
+ # %(prog)s -d data.csv -q MH087032 -s 95
1943
+ # %(prog)s -d data.csv -q MH087032 -s 90 --no-ai --batch query1,query2,query3
1944
+ """
1945
+ )
1946
+
1947
+ parser.add_argument('-d', '--data', required=True,
1948
+ help='Path to CSV data file')
1949
+ parser.add_argument('-q', '--query', required=True,
1950
+ help='Query sequence ID or nucleotide sequence')
1951
+ parser.add_argument('-s', '--similarity', type=float, default=95.0,
1952
+ help='Target similarity percentage (70-99, default: 95)')
1953
+ parser.add_argument('--no-ai', action='store_true',
1954
+ help='Skip AI model training')
1955
+ parser.add_argument('--batch',
1956
+ help='Comma-separated list of query IDs for batch processing')
1957
+ parser.add_argument('--output-dir', default='.',
1958
+ help='Output directory for results')
1959
+ parser.add_argument('--save-json', action='store_true',
1960
+ help='Save detailed results to JSON')
1961
+
1962
+ args = parser.parse_args()
1963
+
1964
+ # Validate arguments
1965
+ if not (70 <= args.similarity <= 99):
1966
+ print("❌ Similarity percentage must be between 70 and 99.")
1967
+ sys.exit(1)
1968
+
1969
+ if not Path(args.data).exists():
1970
+ print(f"❌ Data file not found: {args.data}")
1971
+ sys.exit(1)
1972
+
1973
+ # Initialize analyzer
1974
+ analyzer = PhylogeneticTreeAnalyzer()
1975
+
1976
+ # Load data
1977
+ if not analyzer.load_data(args.data):
1978
+ print("❌ Failed to load data.")
1979
+ sys.exit(1)
1980
+
1981
+ # Train AI model (unless disabled)
1982
+ if not args.no_ai:
1983
+ print("\n⏳ Training AI model... This may take a few moments.", flush=True)
1984
+ start_time = time.time()
1985
+ if analyzer.train_ai_model():
1986
+ elapsed = time.time() - start_time
1987
+ print(f"✅ AI model training completed in {elapsed:.1f} seconds", flush=True)
1988
+ else:
1989
+ print("⚠️ AI model training failed, continuing with basic analysis", flush=True)
1990
+
1991
+ # Process queries
1992
+ queries = args.batch.split(',') if args.batch else [args.query]
1993
+
1994
+ for query in queries:
1995
+ query = query.strip()
1996
+ print(f"\n🔍 Processing: {query}")
1997
+
1998
+ if analyzer.find_query_sequence(query):
1999
+ matched_ids, actual_percentage = analyzer.find_similar_sequences(args.similarity)
2000
+
2001
+ if matched_ids:
2002
+ analyzer.build_tree_structure_with_ml_safe(matched_ids)
2003
+ fig = analyzer.create_interactive_tree(matched_ids, actual_percentage)
2004
+
2005
+ if fig:
2006
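+                     # Save the interactive HTML file under --output-dir
+                     output_dir = Path(args.output_dir)
+                     output_dir.mkdir(parents=True, exist_ok=True)
+                     html_filename = output_dir / f"phylogenetic_tree_{query.replace('/', '_')}_interactive.html"
+                     fig.write_html(str(html_filename))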
2009
+ print(f"📄 Interactive HTML saved: {html_filename}")
2010
+
2011
+ print(f"✅ Analysis completed for {query}")
2012
+ else:
2013
+ print(f"❌ No similar sequences found for {query}")
2014
+ else:
2015
+ print(f"❌ Query not found: {query}")
2016
+
2017
+
2018
+ if __name__ == "__main__":
2019
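+     try:
+         # Use the argparse interface when arguments are given; otherwise run interactively
+         if len(sys.argv) > 1:
+             command_line_interface()
+         else:
+             main()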
2021
+ except KeyboardInterrupt:
2022
+ print(f"\n\n👋 Goodbye!")
2023
+ sys.exit(0)
2024
+ except Exception as e:
2025
+ print(f"\n❌ Unexpected error: {e}")
2026
+ sys.exit(1)
2027
+ #KR815908
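
Note on the likelihood code in this file: `_calculate_site_likelihood_safe` scores each site by nucleotide diversity rather than by a substitution model. For comparison, here is a minimal sketch of a pairwise Jukes-Cantor (JC69) log-likelihood and its closed-form maximum-likelihood distance. The function names `jc69_log_likelihood` and `jc69_distance` are hypothetical and not part of the uploaded file. The integer encoding (0-3 for the four nucleotides, 4 or more for gaps and N) is an assumption taken from the `site_pattern < 4` checks. This covers two sequences only and is not a full tree likelihood.

```python
import numpy as np


def _valid_counts(seq_a, seq_b):
    """Return (n_same, n_diff) over sites where both sequences hold a real nucleotide."""
    a = np.asarray(seq_a)
    b = np.asarray(seq_b)
    valid = (a < 4) & (b < 4)  # assumed encoding: values >= 4 mean gap or N
    n_diff = int(np.sum(a[valid] != b[valid]))
    n_same = int(np.sum(valid)) - n_diff
    return n_same, n_diff


def jc69_log_likelihood(seq_a, seq_b, distance):
    """Log-likelihood of two aligned sequences separated by `distance` under JC69.

    `distance` is in expected substitutions per site.
    """
    if distance < 0:
        raise ValueError("distance must be non-negative")
    n_same, n_diff = _valid_counts(seq_a, seq_b)
    decay = np.exp(-4.0 * distance / 3.0)
    p_same = 0.25 + 0.75 * decay  # P(x -> x | distance)
    p_diff = 0.25 - 0.25 * decay  # P(x -> y | distance) for each y != x
    if n_diff > 0 and p_diff <= 0.0:
        return float("-inf")  # differences are impossible at zero distance
    # Each site contributes the stationary frequency 1/4 times a transition probability
    log_l = (n_same + n_diff) * np.log(0.25) + n_same * np.log(p_same)
    if n_diff > 0:
        log_l += n_diff * np.log(p_diff)
    return float(log_l)


def jc69_distance(seq_a, seq_b):
    """Closed-form ML distance under JC69: d = -3/4 * ln(1 - 4p/3)."""
    n_same, n_diff = _valid_counts(seq_a, seq_b)
    n_valid = n_same + n_diff
    if n_valid == 0:
        return float("nan")  # no comparable sites
    p = n_diff / n_valid
    if p >= 0.75:
        return float("inf")  # saturated: the formula is undefined
    return float(-0.75 * np.log(1.0 - 4.0 * p / 3.0))
```

Calling `jc69_distance` on two rows of the encoded alignment matrix gives the distance that maximizes `jc69_log_likelihood` for that pair, which is one way to sanity-check a heuristic score against a model-based one.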