re-type committed on
Commit
183a83a
·
verified ·
1 Parent(s): ae7bc5c

Delete ml_simplified_tree.py

Browse files
Files changed (1) hide show
  1. ml_simplified_tree.py +0 -2026
ml_simplified_tree.py DELETED
@@ -1,2026 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- """ML simplified tree.ipynb
3
-
4
- Automatically generated by Colab.
5
-
6
- Original file is located at
7
- https://colab.research.google.com/drive/1LiDjip-h70ilIex9PedpWCZARWglija7
8
- """
9
-
10
-
11
- # Commented out IPython magic to ensure Python compatibility.
12
- import pandas as pd
13
- import numpy as np
14
- import plotly.graph_objects as go
15
- import plotly.offline as pyo
16
- from plotly.subplots import make_subplots
17
- from Bio import Phylo, SeqIO, AlignIO
18
- from Bio.Phylo.TreeConstruction import DistanceCalculator, DistanceTreeConstructor
19
- from Bio.Align import MultipleSeqAlignment
20
- from Bio.Seq import Seq
21
- from Bio.SeqRecord import SeqRecord
22
- from sklearn.feature_extraction.text import CountVectorizer
23
- from sklearn.metrics.pairwise import cosine_similarity
24
- from sklearn.ensemble import RandomForestClassifier
25
- from sklearn.model_selection import train_test_split
26
- from sklearn.preprocessing import LabelEncoder
27
- import warnings
28
- import os
29
- import sys
30
- from typing import Dict, List, Tuple, Optional, Any
31
- import json
32
- import re
33
- from scipy.optimize import minimize
34
- from scipy.spatial.distance import pdist, squareform
35
- from Bio.Phylo import BaseTree
36
- import itertools
37
- from collections import defaultdict, deque
38
- import argparse
39
- import time
40
- from pathlib import Path
41
-
42
- warnings.filterwarnings('ignore')
43
-
44
class PhylogeneticTreeAnalyzer:
    """Builds and lays out a simplified phylogenetic tree around a query F-gene sequence."""

    def __init__(self):
        # Input data and query state
        self.data = None                    # pandas DataFrame loaded from CSV
        self.query_sequence = None          # cleaned nucleotide string of the query
        self.query_id = None                # accession number or generated QUERY_xxxxx id
        self.matching_percentage = 95.0     # requested similarity threshold (percent)
        self.actual_percentage = None       # similarity percentage actually achieved
        self.matched_sequences = []
        self.tree_structure = {}
        self.similarity_scores = {}
        # AI classification
        self.ai_model = None
        self.label_encoder = LabelEncoder()
        # ML-specific attributes
        self.ml_tree = None
        self.ml_alignment = None
        self.ml_results = {}
        self.horizontal_line_tracker = []   # lengths of horizontal lines that carry verticals
        self.query_ml_group = None          # base ML group that holds the query sequence
        self.base_horizontal_length = 1.2   # base unit for horizontal branch lengths
- def load_data(self, data_file: str):
67
-
68
- try:
69
- self.data = pd.read_csv(data_file)
70
- # required_columns = ['Accession Number', 'ML', 'Genotype', 'Host',
71
- # 'Country', 'Isolate', 'Year', 'F-gene']
72
-
73
- # missing_columns = [col for col in self.data.columns if col not in required_columns] # Corrected check for missing columns
74
- # if missing_columns:
75
- # print(f"Error: Missing required columns: {missing_columns}")
76
- # return False
77
-
78
- print(f"✓ Data loaded successfully: {len(self.data)} sequences")
79
- print(f"✓ ML Groups found: {self.data['ML'].nunique()}")
80
- print(f"✓ Genotypes found: {self.data['Genotype'].nunique()}")
81
- return True
82
-
83
- except Exception as e:
84
- print(f"Error loading data: {e}")
85
- return False
86
-
87
-
88
- def calculate_f_gene_similarity(self, seq1: str, seq2: str) -> float:
89
-
90
- try:
91
- # Handle empty or None sequences
92
- if not seq1 or not seq2:
93
- return 0.0
94
-
95
- # Convert to uppercase and remove non-nucleotide characters
96
- seq1 = re.sub(r'[^ATGC]', '', str(seq1).upper())
97
- seq2 = re.sub(r'[^ATGC]', '', str(seq2).upper())
98
-
99
- if len(seq1) == 0 or len(seq2) == 0:
100
- return 0.0
101
-
102
- # Use k-mer analysis for similarity calculation
103
- k = 5 # 5-mer analysis
104
- kmers1 = set([seq1[i:i+k] for i in range(len(seq1)-k+1) if len(seq1[i:i+k]) == k])
105
- kmers2 = set([seq2[i:i+k] for i in range(len(seq2)-k+1) if len(seq2[i:i+k]) == k])
106
-
107
- if len(kmers1) == 0 and len(kmers2) == 0:
108
- return 100.0
109
- elif len(kmers1) == 0 or len(kmers2) == 0:
110
- return 0.0
111
-
112
- # Calculate Jaccard similarity
113
- intersection = len(kmers1.intersection(kmers2))
114
- union = len(kmers1.union(kmers2))
115
- similarity = (intersection / union) * 100 if union > 0 else 0.0
116
-
117
- return round(similarity, 2)
118
-
119
- except Exception as e:
120
- print(f"Error calculating similarity: {e}")
121
- return 0.0
122
-
123
- def train_ai_model(self):
124
-
125
- try:
126
-
127
- # Skip training if insufficient data
128
- if len(self.data) < 10: # Require minimum 10 samples
129
- print("⚠️ Insufficient data to train AI model (min 10 samples required)", flush=True)
130
- return False
131
-
132
- print("🤖 Training AI model for sequence analysis...", flush=True)
133
-
134
- # Prepare features from F-gene sequences
135
- f_gene_sequences = self.data['F-gene'].fillna('').astype(str)
136
-
137
- # Create k-mer features (3-mers to 6-mers)
138
- features = []
139
- for seq in f_gene_sequences:
140
- seq_clean = re.sub(r'[^ATGC]', '', seq.upper())
141
- if len(seq_clean) < 3:
142
- features.append([0] * 100) # Placeholder for short sequences
143
- continue
144
-
145
- feature_vector = []
146
- # 3-mers
147
- kmers_3 = [seq_clean[i:i+3] for i in range(len(seq_clean)-2)]
148
- kmer_counts_3 = {kmer: kmers_3.count(kmer) for kmer in set(kmers_3)}
149
-
150
- # 4-mers
151
- kmers_4 = [seq_clean[i:i+4] for i in range(len(seq_clean)-3)]
152
- kmer_counts_4 = {kmer: kmers_4.count(kmer) for kmer in set(kmers_4)}
153
-
154
- # Create feature vector (top 50 3-mers + top 50 4-mers)
155
- all_3mers = [''.join(p) for p in __import__('itertools').product('ATGC', repeat=3)]
156
- all_4mers = [''.join(p) for p in __import__('itertools').product('ATGC', repeat=4)]
157
-
158
- feature_vector.extend([kmer_counts_3.get(kmer, 0) for kmer in all_3mers[:50]])
159
- feature_vector.extend([kmer_counts_4.get(kmer, 0) for kmer in all_4mers[:50]])
160
-
161
- features.append(feature_vector)
162
-
163
- # Prepare target labels (ML groups)
164
- targets = self.label_encoder.fit_transform(self.data['ML'].fillna('Unknown'))
165
-
166
- # Skip if only 1 class
167
- if len(np.unique(targets)) < 2:
168
- print("⚠️ Need at least 2 distinct classes for training", flush=True)
169
- return False
170
-
171
- # Train Random Forest model
172
- X = np.array(features)
173
- y = targets
174
-
175
- X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
176
-
177
- self.ai_model = RandomForestClassifier(n_estimators=100, random_state=42)
178
- self.ai_model.fit(X_train, y_train)
179
-
180
- # Calculate accuracy
181
- accuracy = self.ai_model.score(X_test, y_test)
182
- print(f"✓ AI model trained successfully with accuracy: {accuracy:.2%}", flush=True)
183
-
184
- return True
185
-
186
- except Exception as e:
187
- print(f"🚨 CRITICAL training error: {e}", flush=True)
188
- import traceback
189
- traceback.print_exc()
190
- return False
191
-
192
- def find_query_sequence(self, query_input: str) -> bool:
193
- """
194
- Modified to accept any sequence input, not just those existing in the dataset.
195
- """
196
- try:
197
- # Check if input is an accession number from the dataset
198
- if query_input in self.data['Accession Number'].values:
199
- self.query_id = query_input
200
- query_row = self.data[self.data['Accession Number'] == query_input].iloc[0]
201
- self.query_sequence = query_row['F-gene']
202
- print(f"✓ Query sequence found by ID: {query_input}")
203
- return True
204
-
205
- # Check if input is a nucleotide sequence
206
- query_clean = re.sub(r'[^ATGC]', '', str(query_input).upper())
207
-
208
- # Accept any sequence with reasonable length (even short ones for testing)
209
- if len(query_clean) >= 10: # Minimum sequence length (reduced from 50)
210
- # For sequences not in dataset, create a unique identifier
211
- if query_input not in self.data['Accession Number'].values:
212
- # Generate a unique query ID for novel sequences
213
- self.query_id = f"QUERY_{hash(query_clean) % 100000:05d}"
214
- self.query_sequence = query_clean
215
- print(f"✓ Novel query sequence accepted with ID: {self.query_id}")
216
- print(f" Sequence length: {len(query_clean)} nucleotides")
217
- return True
218
- else:
219
- # If somehow it matches an accession but wasn't caught above
220
- self.query_id = query_input
221
- self.query_sequence = query_clean
222
- print(f"✓ Query sequence processed: {query_input}")
223
- return True
224
-
225
- # If sequence is too short or invalid
226
- if len(query_clean) < 10:
227
- print(f"✗ Query sequence too short. Minimum length: 10 nucleotides (provided: {len(query_clean)})")
228
- else:
229
- print(f"✗ Invalid sequence format. Please provide nucleotides (A, T, G, C) or valid accession number")
230
-
231
- return False
232
-
233
- except Exception as e:
234
- print(f"Error processing query sequence: {e}")
235
- return False
236
-
237
- def find_similar_sequences(self, target_percentage: float) -> Tuple[List[str], float]:
238
- """
239
- Modified to work with any query sequence, including novel ones not in the dataset.
240
- """
241
- try:
242
- print(f"🔍 Finding sequences with {target_percentage}% similarity to query...")
243
- similarities = []
244
-
245
- # Calculate similarity between query and all sequences in dataset
246
- for idx, row in self.data.iterrows():
247
- # Skip if this is the same sequence (only relevant for existing accession numbers)
248
- if hasattr(self, 'query_id') and row['Accession Number'] == self.query_id:
249
- continue
250
-
251
- try:
252
- similarity = self.calculate_f_gene_similarity(self.query_sequence, row['F-gene'])
253
- similarities.append({
254
- 'id': row['Accession Number'],
255
- 'similarity': similarity,
256
- 'ml': row['ML'] if 'ML' in row else 'Unknown',
257
- 'genotype': row['Genotype'] if 'Genotype' in row else 'Unknown'
258
- })
259
- except Exception as seq_error:
260
- print(f"⚠ Skipping sequence {row['Accession Number']}: {seq_error}")
261
- continue
262
-
263
- if not similarities:
264
- print("❌ No valid sequences found for comparison")
265
- return [], target_percentage
266
-
267
- # Sort by similarity (highest first)
268
- similarities.sort(key=lambda x: x['similarity'], reverse=True)
269
-
270
- # Find sequences within target percentage range (±2%)
271
- target_range = 2.0
272
- candidates = [s for s in similarities
273
- if abs(s['similarity'] - target_percentage) <= target_range]
274
-
275
- if not candidates:
276
- # If no exact matches, find sequences with closest similarity
277
- closest_sim = min(similarities, key=lambda x: abs(x['similarity'] - target_percentage))
278
- actual_percentage = closest_sim['similarity']
279
-
280
- # Get sequences within ±1% of the closest similarity
281
- candidates = [s for s in similarities
282
- if abs(s['similarity'] - actual_percentage) <= 1.0]
283
-
284
- print(f"⚠ No sequences found at exactly {target_percentage}%. Using closest: {actual_percentage:.1f}%")
285
- else:
286
- actual_percentage = target_percentage
287
-
288
- # Limit results to prevent overwhelming visualization (optional)
289
- max_results = 50 # Adjust as needed
290
- if len(candidates) > max_results:
291
- candidates = candidates[:max_results]
292
- print(f"⚠ Limited results to top {max_results} matches for better visualization")
293
-
294
- # Store similarity scores for later use
295
- self.similarity_scores = {} # Reset similarity scores
296
- for candidate in candidates:
297
- self.similarity_scores[candidate['id']] = candidate['similarity']
298
-
299
- matched_ids = [c['id'] for c in candidates]
300
-
301
- # Show some statistics
302
- if similarities:
303
- max_sim = max(similarities, key=lambda x: x['similarity'])['similarity']
304
- min_sim = min(similarities, key=lambda x: x['similarity'])['similarity']
305
- avg_sim = sum(s['similarity'] for s in similarities) / len(similarities)
306
-
307
- print(f"✓ Found {len(matched_ids)} sequences at ~{actual_percentage:.1f}% similarity")
308
- print(f" Similarity range in dataset: {min_sim:.1f}% - {max_sim:.1f}% (avg: {avg_sim:.1f}%)")
309
-
310
- return matched_ids, actual_percentage
311
-
312
- except Exception as e:
313
- print(f"Error finding similar sequences: {e}")
314
- return [], target_percentage
315
-
316
-
317
- def build_tree_structure(self, matched_ids: List[str]) -> Dict:
318
- try:
319
- print("🌳 Building normalized horizontal tree structure...")
320
-
321
- # Initialize tree structure
322
- tree_structure = {
323
- 'root': {
324
- 'name': 'Root',
325
- 'type': 'root',
326
- 'children': {},
327
- 'x': 0,
328
- 'y': 0,
329
- 'has_vertical_attachment': False,
330
- 'extension_level': 0
331
- }
332
- }
333
-
334
- # Group sequences by ML and Genotype
335
- ml_groups = {}
336
- for idx, row in self.data.iterrows():
337
- ml_group = row['ML']
338
- genotype = row['Genotype']
339
- seq_id = row['Accession Number']
340
-
341
- if ml_group not in ml_groups:
342
- ml_groups[ml_group] = {}
343
-
344
- if genotype not in ml_groups[ml_group]:
345
- ml_groups[ml_group][genotype] = []
346
-
347
- ml_groups[ml_group][genotype].append({
348
- 'id': seq_id,
349
- 'data': row.to_dict(),
350
- 'is_query': seq_id == self.query_id,
351
- 'is_matched': seq_id in matched_ids,
352
- 'similarity': self.similarity_scores.get(seq_id, 0.0)
353
- })
354
-
355
- # Normalize ML group names and organize
356
- normalized_ml_groups = self._normalize_ml_groups(ml_groups)
357
-
358
- # Build normalized ML level - horizontal layout with progressive extensions
359
- self._build_normalized_ml_nodes(tree_structure, normalized_ml_groups, matched_ids)
360
-
361
- self.tree_structure = tree_structure
362
- print(f"✓ Normalized horizontal tree structure built")
363
- return tree_structure
364
-
365
- except Exception as e:
366
- print(f"Error building tree structure: {e}")
367
- return {}
368
-
369
- def _normalize_ml_groups(self, ml_groups: Dict) -> Dict:
370
- """Normalize ML group names and organize hierarchically"""
371
- try:
372
- normalized_groups = {}
373
-
374
- for ml_name, genotypes in ml_groups.items():
375
- # Extract base ML name
376
- if ml_name.startswith('UNCL'):
377
- base_ml = 'UNCL'
378
- elif '.' in ml_name and any(char.isdigit() for char in ml_name):
379
- # For names like XII.1.2, XII.1, etc., extract the base (XII)
380
- base_ml = ml_name.split('.')[0]
381
- else:
382
- base_ml = ml_name
383
-
384
- # Initialize normalized group structure
385
- if base_ml not in normalized_groups:
386
- normalized_groups[base_ml] = {
387
- 'full_ml_groups': {},
388
- 'representative_sequences': [],
389
- 'has_special_sequences': False
390
- }
391
-
392
- # Check if this ML group has query or matched sequences
393
- has_special = any(
394
- any(seq['is_query'] or seq['is_matched'] for seq in sequences)
395
- for sequences in genotypes.values()
396
- )
397
-
398
- if has_special:
399
- normalized_groups[base_ml]['has_special_sequences'] = True
400
- normalized_groups[base_ml]['full_ml_groups'][ml_name] = genotypes
401
- else:
402
- # Add as representative (limit to 2 representatives)
403
- if len(normalized_groups[base_ml]['representative_sequences']) < 2:
404
- # Get 1-2 representative sequences from this ML group
405
- for genotype, sequences in list(genotypes.items())[:2]:
406
- if len(normalized_groups[base_ml]['representative_sequences']) < 2:
407
- normalized_groups[base_ml]['representative_sequences'].extend(sequences[:1])
408
-
409
- return normalized_groups
410
-
411
- except Exception as e:
412
- print(f"Error normalizing ML groups: {e}")
413
- return {}
414
-
415
- def _build_normalized_ml_nodes(self, tree_structure: Dict, normalized_ml_groups: Dict, matched_ids: List[str]):
416
- """Build normalized ML nodes with equal spacing and progressive horizontal extensions"""
417
- try:
418
- # Reset horizontal line tracker
419
- self.horizontal_line_tracker = []
420
-
421
- # Identify which ML group contains the query
422
- self._identify_query_ml_group(normalized_ml_groups)
423
-
424
- # Calculate equal spacing for all ML groups
425
- ml_positions = self._calculate_dynamic_ml_positions(normalized_ml_groups)
426
-
427
- # Mark root as having vertical attachment if it has multiple children
428
- root_has_vertical = len(normalized_ml_groups) > 1
429
- tree_structure['root']['has_vertical_attachment'] = root_has_vertical
430
-
431
- for ml_idx, (base_ml, ml_data) in enumerate(normalized_ml_groups.items()):
432
- y_pos = ml_positions[ml_idx]
433
-
434
- # Determine if this ML node will have vertical attachments
435
- has_vertical = ml_data['has_special_sequences'] and len(ml_data['full_ml_groups']) > 1
436
-
437
- # Check if this ML group contains the query
438
- contains_query = (base_ml == self.query_ml_group)
439
-
440
- # Calculate horizontal line length based on connections and query presence
441
- horizontal_length = self._determine_horizontal_line_length(
442
- 'normalized_ml_group', has_vertical, contains_query
443
- )
444
-
445
- x_pos = horizontal_length
446
-
447
- # Create normalized ML node
448
- tree_structure['root']['children'][base_ml] = {
449
- 'name': base_ml,
450
- 'type': 'normalized_ml_group',
451
- 'children': {},
452
- 'x': x_pos,
453
- 'y': y_pos,
454
- 'has_special_sequences': ml_data['has_special_sequences'],
455
- 'has_vertical_attachment': has_vertical,
456
- 'horizontal_line_length': horizontal_length,
457
- 'contains_query': contains_query
458
- }
459
-
460
- if ml_data['has_special_sequences']:
461
- # Build full ML nodes for groups with special sequences
462
- self._build_full_ml_nodes(
463
- tree_structure['root']['children'][base_ml],
464
- ml_data['full_ml_groups'],
465
- y_pos,
466
- matched_ids,
467
- x_pos
468
- )
469
- else:
470
- # Add representative sequences directly
471
- self._add_representative_sequences(
472
- tree_structure['root']['children'][base_ml],
473
- ml_data['representative_sequences'],
474
- y_pos,
475
- x_pos
476
- )
477
-
478
- except Exception as e:
479
- print(f"Error building normalized ML nodes: {e}")
480
-
481
- def _calculate_dynamic_ml_positions(self, normalized_ml_groups: Dict) -> List[float]:
482
- """Calculate equal Y positions for all ML groups regardless of content"""
483
- try:
484
- ml_count = len(normalized_ml_groups)
485
- if ml_count == 0:
486
- return []
487
-
488
- if ml_count == 1:
489
- return [0.0]
490
-
491
- # Equal spacing between all ML nodes
492
- total_spacing = (ml_count - 1) * 2.0 # 2.0 units between each ML node
493
- start_y = -total_spacing / 2
494
-
495
- positions = []
496
- for i in range(ml_count):
497
- positions.append(start_y + i * 2.0)
498
-
499
- return positions
500
-
501
- except Exception as e:
502
- print(f"Error calculating dynamic positions: {e}")
503
- return list(range(len(normalized_ml_groups)))
504
-
505
- def _build_full_ml_nodes(self, normalized_ml_node: Dict, full_ml_groups: Dict, base_y: float, matched_ids: List[str], parent_x: float):
506
- """Build full ML nodes with genotypes for groups containing special sequences"""
507
- try:
508
- # Calculate equal positions for full ML groups
509
- full_ml_positions = self._calculate_full_ml_positions(full_ml_groups, base_y)
510
-
511
- for ml_idx, (full_ml_name, genotypes) in enumerate(full_ml_groups.items()):
512
- y_pos = full_ml_positions[ml_idx]
513
-
514
- # Determine if this full ML node will have vertical attachments
515
- special_genotypes_count = sum(1 for genotype, sequences in genotypes.items()
516
- if any(seq['is_query'] or seq['is_matched'] for seq in sequences))
517
- has_vertical = special_genotypes_count > 1
518
-
519
- # Check if this full ML group contains the query
520
- contains_query = any(
521
- any(seq['is_query'] for seq in sequences)
522
- for sequences in genotypes.values()
523
- )
524
-
525
- # Calculate horizontal line length
526
- horizontal_length = self._determine_horizontal_line_length(
527
- 'full_ml_group', has_vertical, contains_query
528
- )
529
-
530
- x_pos = parent_x + horizontal_length
531
-
532
- # Create full ML node
533
- normalized_ml_node['children'][full_ml_name] = {
534
- 'name': full_ml_name,
535
- 'type': 'full_ml_group',
536
- 'children': {},
537
- 'x': x_pos,
538
- 'y': y_pos,
539
- 'sequences_count': sum(len(seqs) for seqs in genotypes.values()),
540
- 'has_vertical_attachment': has_vertical,
541
- 'horizontal_line_length': horizontal_length,
542
- 'contains_query': contains_query
543
- }
544
-
545
- # Build genotype nodes
546
- self._build_genotype_nodes(
547
- normalized_ml_node['children'][full_ml_name],
548
- genotypes,
549
- y_pos,
550
- matched_ids,
551
- x_pos
552
- )
553
-
554
- except Exception as e:
555
- print(f"Error building full ML nodes: {e}")
556
-
557
- def _calculate_full_ml_positions(self, full_ml_groups: Dict, base_y: float) -> List[float]:
558
- """Calculate equal positions for full ML groups"""
559
- try:
560
- ml_count = len(full_ml_groups)
561
- if ml_count <= 1:
562
- return [base_y]
563
-
564
- # Equal spacing for full ML groups
565
- spacing = 1.5 # Fixed spacing between full ML groups
566
- start_y = base_y - (spacing * (ml_count - 1)) / 2
567
-
568
- positions = []
569
- for i in range(ml_count):
570
- positions.append(start_y + i * spacing)
571
-
572
- return positions
573
-
574
- except Exception as e:
575
- print(f"Error calculating full ML positions: {e}")
576
- return [base_y] * len(full_ml_groups)
577
-
578
- def _build_genotype_nodes(self, full_ml_node: Dict, genotypes: Dict, base_y: float, matched_ids: List[str], parent_x: float):
579
- """Build genotype nodes with sequences - horizontal line length based on sequence count"""
580
- try:
581
- # Filter genotypes with special sequences
582
- special_genotypes = []
583
- for genotype, sequences in genotypes.items():
584
- if any(seq['is_query'] or seq['is_matched'] for seq in sequences):
585
- special_genotypes.append((genotype, sequences))
586
-
587
- if not special_genotypes:
588
- return
589
-
590
- # Calculate equal genotype positions (vertical positioning remains equal)
591
- genotype_positions = self._calculate_genotype_positions(special_genotypes, base_y)
592
-
593
- # Calculate sequence counts for each genotype to determine horizontal line lengths
594
- genotype_sequence_counts = []
595
- for genotype, sequences in special_genotypes:
596
- special_sequences = [seq for seq in sequences if seq['is_query'] or seq['is_matched']]
597
- genotype_sequence_counts.append((genotype, sequences, len(special_sequences)))
598
-
599
- for gt_idx, (genotype, sequences, sequence_count) in enumerate(genotype_sequence_counts):
600
- y_pos = genotype_positions[gt_idx]
601
-
602
- # Determine if this genotype will have vertical attachments
603
- special_sequences = [seq for seq in sequences if seq['is_query'] or seq['is_matched']]
604
- has_vertical = len(special_sequences) > 1
605
-
606
- # Check if this genotype contains the query
607
- contains_query = any(seq['is_query'] for seq in sequences)
608
-
609
- # Calculate horizontal line length based on sequence count
610
- horizontal_length = self._determine_genotype_horizontal_line_length(
611
- sequence_count, has_vertical, contains_query
612
- )
613
-
614
- x_pos = parent_x + horizontal_length
615
-
616
- # Create genotype node
617
- full_ml_node['children'][genotype] = {
618
- 'name': genotype,
619
- 'type': 'genotype',
620
- 'children': {},
621
- 'x': x_pos,
622
- 'y': y_pos,
623
- 'sequences': sequences,
624
- 'has_vertical_attachment': has_vertical,
625
- 'horizontal_line_length': horizontal_length,
626
- 'contains_query': contains_query,
627
- 'sequence_count': sequence_count # Store for reference
628
- }
629
-
630
- # Add sequences horizontally
631
- self._add_sequences_horizontal(
632
- full_ml_node['children'][genotype],
633
- sequences,
634
- y_pos,
635
- x_pos
636
- )
637
-
638
- except Exception as e:
639
- print(f"Error building genotype nodes: {e}")
640
-
641
- def _determine_genotype_horizontal_line_length(self, sequence_count: int, has_vertical: bool, contains_query: bool = False) -> float:
642
- """Determine horizontal line length for genotype nodes based on sequence count"""
643
- try:
644
- base_length = self.base_horizontal_length
645
-
646
- # Special case: Genotype containing query sequence gets additional length
647
- query_bonus = 0.5 if contains_query else 0.0
648
-
649
- # Calculate length based on sequence count
650
- # More sequences = longer horizontal line
651
- if sequence_count <= 1:
652
- # Single sequence
653
- length_multiplier = 1.0
654
- elif sequence_count <= 3:
655
- # 2-3 sequences
656
- length_multiplier = 1.6
657
- elif sequence_count <= 5:
658
- # 4-5 sequences
659
- length_multiplier = 2.3
660
- elif sequence_count <= 8:
661
- # 6-8 sequences
662
- length_multiplier = 6.0
663
- else:
664
- # More than 8 sequences
665
- length_multiplier = 6.0
666
-
667
- # Calculate final length
668
- calculated_length = base_length * length_multiplier + query_bonus
669
-
670
- return calculated_length
671
-
672
- except Exception as e:
673
- print(f"Error determining genotype horizontal line length: {e}")
674
- return self.base_horizontal_length
675
-
676
- def _calculate_genotype_positions(self, special_genotypes: List, base_y: float) -> List[float]:
677
- """Calculate equal positions for genotypes"""
678
- try:
679
- genotype_count = len(special_genotypes)
680
- if genotype_count <= 1:
681
- return [base_y]
682
-
683
- # Equal spacing for genotypes
684
- spacing = 1.0 # Fixed spacing between genotypes
685
- start_y = base_y - (spacing * (genotype_count - 1)) / 2
686
-
687
- positions = []
688
- for i in range(genotype_count):
689
- positions.append(start_y + i * spacing)
690
-
691
- return positions
692
-
693
- except Exception as e:
694
- print(f"Error calculating genotype positions: {e}")
695
- return [base_y] * len(special_genotypes)
696
-
697
- def _add_representative_sequences(self, normalized_ml_node: Dict, representative_sequences: List[Dict], base_y: float, parent_x: float):
698
- """Add representative sequences directly to normalized ML node"""
699
- try:
700
- if not representative_sequences:
701
- return
702
-
703
- # Calculate horizontal line length for representative sequences
704
- # Representative sequences get a standard length (not similarity-based since they're not matched)
705
- has_vertical = len(representative_sequences) > 1
706
- horizontal_length = self._determine_horizontal_line_length('representative', has_vertical)
707
- x_pos = parent_x + horizontal_length
708
-
709
- if len(representative_sequences) == 1:
710
- seq = representative_sequences[0]
711
- normalized_ml_node['children'][f"{seq['id']}_rep"] = {
712
- 'name': f"{seq['id']} (Rep)",
713
- 'type': 'representative_sequence',
714
- 'data': seq,
715
- 'x': x_pos,
716
- 'y': base_y,
717
- 'has_vertical_attachment': False,
718
- 'horizontal_line_length': horizontal_length
719
- }
720
- else:
721
- # Equal spacing for multiple representative sequences
722
- positions = self._calculate_sequence_positions(representative_sequences, base_y)
723
-
724
- for idx, seq in enumerate(representative_sequences):
725
- normalized_ml_node['children'][f"{seq['id']}_rep"] = {
726
- 'name': f"{seq['id']} (Rep)",
727
- 'type': 'representative_sequence',
728
- 'data': seq,
729
- 'x': x_pos,
730
- 'y': positions[idx],
731
- 'has_vertical_attachment': False,
732
- 'horizontal_line_length': horizontal_length
733
- }
734
-
735
- except Exception as e:
736
- print(f"Error adding representative sequences: {e}")
737
-
738
- def _add_sequences_horizontal(self, genotype_node: Dict, sequences: List[Dict], base_y: float, parent_x: float):
739
- """Add sequences horizontally with similarity-based line lengths"""
740
- try:
741
- # Define the query line length as the reference (100%)
742
- query_line_length = 3.0 # Base length for query sequence (100%)
743
-
744
- # Separate query and matched sequences
745
- query_sequences = [seq for seq in sequences if seq['is_query']]
746
- matched_sequences = [seq for seq in sequences if seq['is_matched'] and not seq['is_query']]
747
-
748
- all_special_sequences = query_sequences + matched_sequences
749
-
750
- if len(all_special_sequences) == 1:
751
- # Single sequence - direct line with similarity-based length
752
- sequence = all_special_sequences[0]
753
- line_length = self._calculate_similarity_based_line_length(sequence, query_line_length)
754
- x_pos = parent_x + line_length
755
-
756
- genotype_node['children'][sequence['id']] = {
757
- 'name': f"{sequence['id']}{' (' + str(sequence['similarity']) + '%)' if sequence['is_matched'] else ''}",
758
- 'type': 'sequence',
759
- 'data': sequence,
760
- 'x': x_pos,
761
- 'y': base_y,
762
- 'has_vertical_attachment': False,
763
- 'similarity_line_length': line_length
764
- }
765
- else:
766
- # Multiple sequences - equal vertical distribution with similarity-based horizontal lengths
767
- sequence_positions = self._calculate_sequence_positions(all_special_sequences, base_y)
768
-
769
- for seq_idx, sequence in enumerate(all_special_sequences):
770
- line_length = self._calculate_similarity_based_line_length(sequence, query_line_length)
771
- x_pos = parent_x + line_length
772
-
773
- genotype_node['children'][sequence['id']] = {
774
- 'name': f"{sequence['id']}{' (' + str(sequence['similarity']) + '%)' if sequence['is_matched'] else ''}",
775
- 'type': 'sequence',
776
- 'data': sequence,
777
- 'x': x_pos,
778
- 'y': sequence_positions[seq_idx],
779
- 'has_vertical_attachment': False,
780
- 'similarity_line_length': line_length
781
- }
782
- except Exception as e:
783
- print(f"Error adding sequences horizontally: {e}")
784
-
785
- def _calculate_similarity_based_line_length(self, sequence: Dict, query_line_length: float) -> float:
786
- """Calculate line length based on similarity percentage relative to query"""
787
- try:
788
- if sequence['is_query']:
789
- # Query sequence gets 100% length
790
- return query_line_length
791
- elif sequence['is_matched']:
792
- # Matched sequences get length proportional to their similarity
793
- similarity = sequence['similarity']
794
- # Convert similarity percentage to proportional length
795
- proportional_length = (similarity / 100.0) * query_line_length
796
- # Ensure minimum length for visibility
797
- min_length = query_line_length * 0.2 # Minimum 20% of query length
798
- return max(proportional_length, min_length)
799
- else:
800
- # Other sequences get a standard length (50% of query)
801
- return query_line_length * 0.5
802
- except Exception as e:
803
- print(f"Error calculating similarity-based line length: {e}")
804
- return query_line_length * 0.5
805
-
806
-
807
- def _calculate_sequence_positions(self, sequences: List[Dict], base_y: float) -> List[float]:
808
- """Calculate equal positions for sequences"""
809
- try:
810
- seq_count = len(sequences)
811
- if seq_count <= 1:
812
- return [base_y]
813
-
814
- # Equal spacing for sequences
815
- spacing = 0.8 # Fixed spacing between sequences
816
- start_y = base_y - (spacing * (seq_count - 1)) / 2
817
-
818
- positions = []
819
- for i in range(seq_count):
820
- positions.append(start_y + i * spacing)
821
-
822
- return positions
823
-
824
- except Exception as e:
825
- print(f"Error calculating sequence positions: {e}")
826
- return [base_y] * len(sequences)
827
-
828
- def _determine_horizontal_line_length(self, node_type: str, has_vertical: bool, contains_query: bool = False) -> float:
829
- """Determine horizontal line length based on node type and connections"""
830
- try:
831
- base_length = self.base_horizontal_length
832
-
833
- # Special case: ML group containing query sequence gets much longer line
834
- if contains_query and node_type == 'normalized_ml_group':
835
- return base_length * 2.5 # Much longer for query ML group
836
-
837
- # If this node has a vertical line attachment (connects to multiple children)
838
- if has_vertical:
839
- # Find the current longest horizontal line with vertical
840
- current_max = base_length
841
- for tracked_length in self.horizontal_line_tracker:
842
- if tracked_length > current_max:
843
- current_max = tracked_length
844
-
845
- # Make this line incrementally longer
846
- new_length = current_max + 0.3
847
- self.horizontal_line_tracker.append(new_length)
848
- return new_length
849
- else:
850
- # Direct connection (no vertical), use base length
851
- return base_length
852
-
853
- except Exception as e:
854
- print(f"Error determining horizontal line length: {e}")
855
- return self.base_horizontal_length
856
-
857
- def _identify_query_ml_group(self, normalized_ml_groups: Dict):
858
- """Identify which ML group contains the query sequence"""
859
- try:
860
- for base_ml, ml_data in normalized_ml_groups.items():
861
- if ml_data['has_special_sequences']:
862
- for full_ml_name, genotypes in ml_data['full_ml_groups'].items():
863
- for genotype, sequences in genotypes.items():
864
- if any(seq['is_query'] for seq in sequences):
865
- self.query_ml_group = base_ml
866
- return
867
- except Exception as e:
868
- print(f"Error identifying query ML group: {e}")
869
-
870
    # NOTE(review): exact duplicate of `_identify_query_ml_group` defined immediately
    # above — Python silently keeps only this later definition, so one copy is dead
    # code. Safe to delete either one; kept byte-identical here.
    def _identify_query_ml_group(self, normalized_ml_groups: Dict):
        """Identify which ML group contains the query sequence"""
        try:
            for base_ml, ml_data in normalized_ml_groups.items():
                if ml_data['has_special_sequences']:
                    for full_ml_name, genotypes in ml_data['full_ml_groups'].items():
                        for genotype, sequences in genotypes.items():
                            # First sequence flagged is_query wins; record its base ML name.
                            if any(seq['is_query'] for seq in sequences):
                                self.query_ml_group = base_ml
                                return
        except Exception as e:
            print(f"Error identifying query ML group: {e}")
882
-
883
- def _calculate_sequence_x_position_horizontal(self, sequence: Dict, max_similarity: float) -> float:
884
- """Calculate X position based on similarity percentage for horizontal layout"""
885
- # This function is now replaced by _calculate_similarity_based_line_length
886
- # Keeping for backward compatibility, but the new approach is used in _add_sequences_horizontal
887
-
888
- base_x = 0 # Relative to parent genotype node
889
- query_line_length = 3.0 # Reference length for query (100%)
890
-
891
- if sequence['is_query']:
892
- return base_x + query_line_length # 100% length for query
893
- elif sequence['is_matched']:
894
- # Line length varies based on similarity percentage
895
- similarity = sequence['similarity']
896
- proportional_length = (similarity / 100.0) * query_line_length
897
- min_length = query_line_length * 0.2 # Minimum 20% of query length
898
- return base_x + max(proportional_length, min_length)
899
- else:
900
- return base_x + (query_line_length * 0.5) # 50% length for other sequences
901
-
902
-
903
    def create_interactive_tree(self, matched_ids: List[str], actual_percentage: float):
        """Render ``self.tree_structure`` as a compact horizontal Plotly tree.

        Walks the previously built hierarchy (root -> normalized ML groups ->
        full ML groups -> genotypes -> sequences), draws right-angled connector
        edges plus colored/sized node markers with hover details, writes the
        figure to 'phylogenetic_tree_normalized_horizontal.html' and shows it.

        Args:
            matched_ids: accession numbers of sequences matched to the query
                (used only for the title counts).
            actual_percentage: the similarity threshold actually achieved
                (title only).

        Returns:
            The plotly Figure, or None when anything fails.

        NOTE(review): assumes every node dict in ``self.tree_structure`` already
        carries 'x'/'y' coordinates assigned by an earlier layout pass — confirm
        against the builder methods.
        """
        try:
            print("🎨 Creating horizontal interactive tree visualization...")

            # Prepare data for plotting — flat coordinate/attribute arrays that
            # plotly consumes (edges as None-separated polyline segments).
            edge_x = []
            edge_y = []
            node_x = []
            node_y = []
            node_colors = []
            node_text = []
            node_hover = []
            node_sizes = []

            # Updated color scheme for new node types
            colors = {
                'root': '#FF0000',  # Red for root
                'normalized_ml_group': '#FFB6C1',  # Light pink for normalized ML groups
                'full_ml_group': '#FF69B4',  # Hot pink for full ML groups
                'genotype': '#FFD700',  # Gold for genotypes
                'representative_sequence': '#FFA500',  # Orange for representative sequences
                'query_sequence': '#4B0082',  # Dark purple for query
                'matched_sequence': '#6A5ACD',  # Slate blue for matched
                'other_sequence': '#87CEEB'  # Sky blue for others
            }

            def add_horizontal_edges(parent_x, parent_y, children_dict):
                """Add horizontal connecting lines with proper vertical line sizing"""
                if not children_dict:
                    return

                children_list = list(children_dict.values())

                if len(children_list) == 1:
                    # Single child - direct horizontal line
                    child = children_list[0]
                    edge_x.extend([parent_x, child['x'], None])
                    edge_y.extend([parent_y, child['y'], None])
                else:
                    # Multiple children - horizontal line with vertical distribution
                    # Calculate the intermediate x position (where vertical line will be)
                    child_x_positions = [child['x'] for child in children_list]
                    min_child_x = min(child_x_positions)
                    intermediate_x = parent_x + (min_child_x - parent_x) * 0.8  # 80% of the way to nearest child

                    # Horizontal line to intermediate point
                    edge_x.extend([parent_x, intermediate_x, None])
                    edge_y.extend([parent_y, parent_y, None])

                    # Calculate vertical line range to fit exactly all children
                    child_y_positions = [child['y'] for child in children_list]
                    min_y, max_y = min(child_y_positions), max(child_y_positions)

                    # Vertical line sized exactly to fit all children
                    edge_x.extend([intermediate_x, intermediate_x, None])
                    edge_y.extend([min_y, max_y, None])

                    # Horizontal lines from vertical line to each child
                    for child in children_list:
                        edge_x.extend([intermediate_x, child['x'], None])
                        edge_y.extend([child['y'], child['y'], None])

            def get_node_color_and_size(node):
                """Determine node color and size based on type and content"""
                if node['type'] == 'sequence':
                    if node['data']['is_query']:
                        return colors['query_sequence'], 10  # Reduced size for compactness
                    elif node['data']['is_matched']:
                        return colors['matched_sequence'], 8
                    else:
                        return colors['other_sequence'], 6
                elif node['type'] == 'representative_sequence':
                    return colors['representative_sequence'], 7
                elif node['type'] == 'normalized_ml_group':
                    # Larger size if it has special sequences
                    size = 9 if node.get('has_special_sequences', False) else 7
                    return colors['normalized_ml_group'], size
                elif node['type'] == 'full_ml_group':
                    return colors['full_ml_group'], 8
                elif node['type'] == 'genotype':
                    return colors['genotype'], 7
                else:
                    return colors.get(node['type'], '#000000'), 7

            def create_node_text(node):
                """Create appropriate text label for each node type"""
                if node['type'] == 'sequence':
                    if node['data']['is_matched'] and not node['data']['is_query']:
                        return f"{node['name']}"
                    else:
                        return node['name']
                elif node['type'] == 'representative_sequence':
                    return node['name']
                elif node['type'] == 'normalized_ml_group':
                    # Add indicator if it has special sequences
                    suffix = " *" if node.get('has_special_sequences', False) else ""
                    return f"{node['name']}{suffix}"
                else:
                    return node['name']

            def create_hover_text(node):
                """Create detailed hover text for each node type"""
                if node['type'] == 'sequence':
                    data = node['data']['data']
                    hover_text = (
                        f"<b>{node['name']}</b><br>"
                        f"Type: {'Query Sequence' if node['data']['is_query'] else 'Matched Sequence' if node['data']['is_matched'] else 'Other Sequence'}<br>"
                        f"ML Group: {data.get('ML', 'N/A')}<br>"
                        f"Genotype: {data.get('Genotype', 'N/A')}<br>"
                        f"Host: {data.get('Host', 'N/A')}<br>"
                        f"Country: {data.get('Country', 'N/A')}<br>"
                        f"Isolate: {data.get('Isolate', 'N/A')}<br>"
                        f"Year: {data.get('Year', 'N/A')}"
                    )
                    if node['data']['is_matched']:
                        hover_text += f"<br><b>Similarity: {node['data']['similarity']}%</b>"
                elif node['type'] == 'representative_sequence':
                    data = node['data']['data']
                    hover_text = (
                        f"<b>{node['name']}</b><br>"
                        f"Type: Representative Sequence<br>"
                        f"ML Group: {data.get('ML', 'N/A')}<br>"
                        f"Genotype: {data.get('Genotype', 'N/A')}<br>"
                        f"Host: {data.get('Host', 'N/A')}<br>"
                        f"Country: {data.get('Country', 'N/A')}"
                    )
                elif node['type'] == 'normalized_ml_group':
                    hover_text = f"<b>{node['name']}</b><br>Type: Normalized ML Group"
                    if node.get('has_special_sequences', False):
                        hover_text += "<br>Contains query/matched sequences"
                    else:
                        hover_text += "<br>Representative sequences only"
                elif node['type'] == 'full_ml_group':
                    hover_text = f"<b>{node['name']}</b><br>Type: Full ML Group"
                    if 'sequences_count' in node:
                        hover_text += f"<br>Total Sequences: {node['sequences_count']}"
                elif node['type'] == 'genotype':
                    hover_text = f"<b>{node['name']}</b><br>Type: Genotype"
                    if 'sequences' in node:
                        special_count = sum(1 for seq in node['sequences'] if seq['is_query'] or seq['is_matched'])
                        hover_text += f"<br>Special Sequences: {special_count}/{len(node['sequences'])}"
                else:
                    hover_text = f"<b>{node['name']}</b><br>Type: {node['type'].replace('_', ' ').title()}"

                return hover_text

            def add_node_and_edges(node, parent_x=None, parent_y=None):
                """Recursively add nodes and edges to the plot with equal spacing structure."""
                x, y = node['x'], node['y']
                node_x.append(x)
                node_y.append(y)

                # Get node color and size
                color, size = get_node_color_and_size(node)
                node_colors.append(color)
                node_sizes.append(size)

                # Create node text and hover
                node_text.append(create_node_text(node))
                node_hover.append(create_hover_text(node))

                # Process children with equal spacing structure
                if 'children' in node and node['children']:
                    add_horizontal_edges(x, y, node['children'])
                    for child in node['children'].values():
                        add_node_and_edges(child, x, y)

            # Build the plot data starting from root
            root_node = self.tree_structure['root']
            add_node_and_edges(root_node)

            # Add horizontal edges for root level
            if root_node['children']:
                add_horizontal_edges(root_node['x'], root_node['y'], root_node['children'])

            # Create the figure
            fig = go.Figure()

            # Add edges
            fig.add_trace(go.Scatter(
                x=edge_x, y=edge_y,
                mode='lines',
                line=dict(width=1, color='gray', dash='solid'),  # Thinner lines for compactness
                hoverinfo='none',
                showlegend=False,
                name='Edges'
            ))

            # Add nodes
            fig.add_trace(go.Scatter(
                x=node_x, y=node_y,
                mode='markers+text',
                marker=dict(
                    size=node_sizes,
                    color=node_colors,
                    line=dict(width=1, color='black'),  # Thinner borders
                    opacity=0.85
                ),
                text=node_text,
                textposition="middle right",
                textfont=dict(size=9, color="black"),  # Smaller font for compactness
                hoverinfo='text',
                hovertext=node_hover,
                showlegend=False,
                name='Nodes'
            ))

            # Calculate proper layout dimensions to ensure everything fits
            if node_x and node_y:
                # Get the actual data bounds
                min_x, max_x = min(node_x), max(node_x)
                min_y, max_y = min(node_y), max(node_y)

                # Calculate ranges
                x_range = max_x - min_x
                y_range = max_y - min_y

                # Add padding to ensure nothing is cut off (20% padding on each side)
                x_padding = x_range * 0.2 if x_range > 0 else 1
                y_padding = y_range * 0.2 if y_range > 0 else 1

                # Set axis ranges with padding
                x_axis_range = [min_x - x_padding, max_x + x_padding]
                y_axis_range = [min_y - y_padding, max_y + y_padding]

                # Compact but sufficient sizing
                width = min(1400, max(800, int(x_range * 80 + 400)))  # Cap max width
                height = min(900, max(500, int(y_range * 40 + 300)))  # Cap max height
            else:
                width, height = 800, 500
                x_axis_range = None
                y_axis_range = None

            # Update layout for compact horizontal tree with proper bounds
            fig.update_layout(
                title=dict(
                    text=f"Compact Horizontal Phylogenetic Tree (ML-Based)<br>"
                         f"Query: {self.query_id} | Similarity: {actual_percentage}% | "
                         f"Matched: {len(matched_ids)}",
                    x=0.5,
                    font=dict(size=12)  # Smaller title for compactness
                ),
                xaxis=dict(
                    showgrid=False,
                    gridcolor='lightgray',
                    gridwidth=0.3,  # Very thin grid lines
                    zeroline=False,
                    showticklabels=False,
                    range=x_axis_range,  # Set explicit range to prevent cutoff
                    fixedrange=False,  # Allow zooming if needed
                    automargin=True  # Automatically adjust margins
                ),
                yaxis=dict(
                    showgrid=False,
                    gridcolor='lightgray',
                    gridwidth=0.3,  # Very thin grid lines
                    zeroline=False,
                    showticklabels=False,
                    range=y_axis_range,  # Set explicit range to prevent cutoff
                    fixedrange=False,  # Allow zooming if needed
                    automargin=True  # Automatically adjust margins
                ),
                plot_bgcolor="white",
                paper_bgcolor="white",
                hovermode="closest",
                width=width,
                height=height,
                margin=dict(l=20, r=100, t=40, b=10),  # Adequate margins, extra right margin for text
                autosize=False,  # Don't auto-resize
                showlegend=True,
                legend=dict(
                    x=1.02,  # Position legend outside plot area
                    y=1,
                    xanchor='left',
                    yanchor='top',
                    bgcolor='rgba(255,255,255,0.8)',
                    bordercolor='gray',
                    borderwidth=1,
                    font=dict(size=10)  # Smaller legend font
                )
            )

            # Add comprehensive legend with smaller markers
            legend_elements = [
                dict(name="Root", marker=dict(color=colors['root'], size=8)),
                dict(name="Normalized ML Groups", marker=dict(color=colors['normalized_ml_group'], size=8)),
                dict(name="Full ML Groups", marker=dict(color=colors['full_ml_group'], size=8)),
                dict(name="Genotypes", marker=dict(color=colors['genotype'], size=8)),
                dict(name="Query Sequence", marker=dict(color=colors['query_sequence'], size=10)),
                dict(name="Similar Sequences", marker=dict(color=colors['matched_sequence'], size=9)),
                dict(name="Representative Sequences", marker=dict(color=colors['representative_sequence'], size=8)),
                dict(name="Other Sequences", marker=dict(color=colors['other_sequence'], size=7))
            ]

            # Invisible dummy traces exist only to populate the legend.
            for i, element in enumerate(legend_elements):
                fig.add_trace(go.Scatter(
                    x=[None], y=[None],
                    mode='markers',
                    marker=element['marker'],
                    name=element['name'],
                    showlegend=True
                ))

            # Configure modebar for better user experience
            config = {
                'displayModeBar': True,
                'displaylogo': False,
                'modeBarButtonsToRemove': ['select2d', 'lasso2d'],
                'toImageButtonOptions': {
                    'format': 'png',
                    'filename': 'phylogenetic_tree',
                    'height': height,
                    'width': width,
                    'scale': 2
                }
            }

            # Save outputs
            try:
                fig.write_html("phylogenetic_tree_normalized_horizontal.html", config=config)
                print("✓ Compact horizontal interactive tree saved as 'phylogenetic_tree_normalized_horizontal.html'")
            except Exception as e:
                print(f"Warning: Could not save HTML file: {e}")

            # Display the figure with config
            try:
                fig.show(config=config)
            except Exception as e:
                print(f"Warning: Could not display figure: {e}")

            return fig

        except Exception as e:
            print(f"Error creating compact horizontal interactive tree: {e}")
            return None
1239
-
1240
-
1241
- def create_sequence_alignment(self, sequence_ids: List[str]) -> Optional[MultipleSeqAlignment]:
1242
-
1243
- try:
1244
- print("🧬 Creating multiple sequence alignment...")
1245
-
1246
- # Get sequences
1247
- sequences = []
1248
- for seq_id in sequence_ids:
1249
- try:
1250
- row = self.data[self.data['Accession Number'] == seq_id]
1251
- if not row.empty:
1252
- f_gene = str(row.iloc[0]['F-gene'])
1253
- # Clean sequence (remove non-nucleotide characters)
1254
- clean_seq = re.sub(r'[^ATGCN-]', '', f_gene.upper())
1255
- if len(clean_seq) > 10: # Minimum sequence length
1256
- seq_record = SeqRecord(Seq(clean_seq), id=seq_id, description="")
1257
- sequences.append(seq_record)
1258
- except Exception as e:
1259
- print(f"Warning: Skipping sequence {seq_id}: {e}")
1260
- continue
1261
-
1262
- if len(sequences) < 2:
1263
- print("❌ Need at least 2 valid sequences for alignment")
1264
- return None
1265
-
1266
- # Simple alignment (you might want to use MUSCLE or CLUSTAL for better results)
1267
- aligned_sequences = self._simple_alignment(sequences)
1268
-
1269
- print(f"✓ Alignment created with {len(aligned_sequences)} sequences")
1270
- return MultipleSeqAlignment(aligned_sequences)
1271
-
1272
- except Exception as e:
1273
- print(f"Error creating alignment: {e}")
1274
- return None
1275
-
1276
- def _simple_alignment(self, sequences: List[SeqRecord]) -> List[SeqRecord]:
1277
-
1278
- try:
1279
- # Find maximum length
1280
- max_length = max(len(seq.seq) for seq in sequences)
1281
-
1282
- # Cap maximum length to prevent memory issues
1283
- if max_length > 10000:
1284
- max_length = 10000
1285
- print(f"Warning: Sequences truncated to {max_length} bp")
1286
-
1287
- # Pad sequences to same length
1288
- aligned_sequences = []
1289
- for seq in sequences:
1290
- seq_str = str(seq.seq)[:max_length] # Truncate if too long
1291
-
1292
- if len(seq_str) < max_length:
1293
- # Pad with gaps at the end
1294
- padded_seq = seq_str + '-' * (max_length - len(seq_str))
1295
- else:
1296
- padded_seq = seq_str
1297
-
1298
- aligned_sequences.append(SeqRecord(Seq(padded_seq), id=seq.id, description=seq.description))
1299
-
1300
- return aligned_sequences
1301
- except Exception as e:
1302
- print(f"Error in simple alignment: {e}")
1303
- return sequences # Return original sequences as fallback
1304
-
1305
- def calculate_ml_distances(self, alignment: MultipleSeqAlignment) -> np.ndarray:
1306
-
1307
- try:
1308
- print("📊 Calculating ML distances...")
1309
-
1310
- # Convert alignment to numeric matrix
1311
- seq_matrix = self._alignment_to_matrix(alignment)
1312
- n_sequences = len(alignment)
1313
-
1314
- if n_sequences == 0:
1315
- return np.array([])
1316
-
1317
- # Initialize distance matrix
1318
- distance_matrix = np.zeros((n_sequences, n_sequences))
1319
-
1320
- # Calculate pairwise ML distances
1321
- for i in range(n_sequences):
1322
- for j in range(i + 1, n_sequences):
1323
- try:
1324
- ml_distance = self._calculate_ml_distance_pair(seq_matrix[i], seq_matrix[j])
1325
- distance_matrix[i][j] = ml_distance
1326
- distance_matrix[j][i] = ml_distance
1327
- except Exception as e:
1328
- print(f"Warning: Error calculating distance between sequences {i} and {j}: {e}")
1329
- # Use maximum distance as fallback
1330
- distance_matrix[i][j] = 1.0
1331
- distance_matrix[j][i] = 1.0
1332
-
1333
- print("✓ ML distances calculated")
1334
- return distance_matrix
1335
-
1336
- except Exception as e:
1337
- print(f"Error calculating ML distances: {e}")
1338
- return np.array([])
1339
-
1340
- def _alignment_to_matrix(self, alignment: MultipleSeqAlignment) -> np.ndarray:
1341
-
1342
- try:
1343
- # Nucleotide to number mapping
1344
- nucleotide_map = {'A': 0, 'T': 1, 'G': 2, 'C': 3, 'N': 4, '-': 5}
1345
-
1346
- matrix = []
1347
- for record in alignment:
1348
- sequence = str(record.seq).upper()
1349
- numeric_seq = [nucleotide_map.get(nuc, 4) for nuc in sequence]
1350
- matrix.append(numeric_seq)
1351
-
1352
- return np.array(matrix)
1353
- except Exception as e:
1354
- print(f"Error converting alignment to matrix: {e}")
1355
- return np.array([])
1356
-
1357
- def _calculate_ml_distance_pair(self, seq1: np.ndarray, seq2: np.ndarray) -> float:
1358
-
1359
- try:
1360
- if len(seq1) == 0 or len(seq2) == 0:
1361
- return 1.0
1362
-
1363
- # Count differences (excluding gaps and N's)
1364
- valid_positions = (seq1 < 4) & (seq2 < 4) # Exclude N's and gaps
1365
-
1366
- if np.sum(valid_positions) == 0:
1367
- return 1.0 # Maximum distance if no valid comparisons
1368
-
1369
- differences = np.sum(seq1[valid_positions] != seq2[valid_positions])
1370
- total_valid = np.sum(valid_positions)
1371
-
1372
- if total_valid == 0:
1373
- return 1.0
1374
-
1375
- # Calculate proportion of differences
1376
- p = differences / total_valid
1377
-
1378
- # Jukes-Cantor correction
1379
- if p >= 0.75:
1380
- return 1.0 # Maximum distance
1381
-
1382
- # JC distance formula: -3/4 * ln(1 - 4p/3)
1383
- try:
1384
- jc_distance = -0.75 * np.log(1 - (4 * p / 3))
1385
- return min(max(jc_distance, 0.0), 1.0) # Clamp between 0 and 1
1386
- except (ValueError, RuntimeWarning):
1387
- return 1.0 # Return maximum distance if log calculation fails
1388
-
1389
- except Exception as e:
1390
- return 1.0 # Return maximum distance on error
1391
-
1392
- def construct_ml_tree(self, alignment: MultipleSeqAlignment) -> Optional[BaseTree.Tree]:
1393
-
1394
- try:
1395
- print("🌳 Constructing Maximum Likelihood tree...")
1396
-
1397
- # Calculate ML distance matrix
1398
- distance_matrix = self.calculate_ml_distances(alignment)
1399
-
1400
- if distance_matrix.size == 0:
1401
- return None
1402
-
1403
- # Create sequence names list
1404
- sequence_names = [record.id for record in alignment]
1405
-
1406
- # Build tree using neighbor-joining on ML distances
1407
- tree = self._build_nj_tree_from_distances(distance_matrix, sequence_names)
1408
-
1409
- # Optimize branch lengths using ML (with recursion protection)
1410
- if tree:
1411
- tree = self._optimize_branch_lengths_ml_safe(tree, alignment)
1412
-
1413
- print("✓ ML tree constructed successfully")
1414
- return tree
1415
-
1416
- except Exception as e:
1417
- print(f"Error constructing ML tree: {e}")
1418
- return None
1419
-
1420
- def _build_nj_tree_from_distances(self, distance_matrix: np.ndarray, sequence_names: List[str]) -> Optional[BaseTree.Tree]:
1421
-
1422
- try:
1423
- from Bio.Phylo.TreeConstruction import DistanceMatrix, DistanceTreeConstructor
1424
-
1425
- # Validate inputs
1426
- if distance_matrix.shape[0] != len(sequence_names):
1427
- print("Error: Distance matrix size doesn't match sequence names")
1428
- return None
1429
-
1430
- # Convert numpy array to Bio.Phylo distance matrix format
1431
- matrix_data = []
1432
- for i in range(len(sequence_names)):
1433
- row = []
1434
- for j in range(i + 1):
1435
- if i == j:
1436
- row.append(0.0)
1437
- else:
1438
- # Ensure distance is valid
1439
- dist = float(distance_matrix[i][j])
1440
- if np.isnan(dist) or np.isinf(dist):
1441
- dist = 1.0
1442
- row.append(max(0.0, dist)) # Ensure non-negative
1443
- matrix_data.append(row)
1444
-
1445
- # Create DistanceMatrix object
1446
- dm = DistanceMatrix(names=sequence_names, matrix=matrix_data)
1447
-
1448
- # Build tree using Neighbor-Joining
1449
- constructor = DistanceTreeConstructor()
1450
- tree = constructor.nj(dm)
1451
-
1452
- # Validate tree structure
1453
- if tree and self._validate_tree_structure(tree):
1454
- return tree
1455
- else:
1456
- print("Warning: Tree structure validation failed")
1457
- return tree # Return anyway, might still be usable
1458
-
1459
- except Exception as e:
1460
- print(f"Error building NJ tree: {e}")
1461
- return None
1462
-
1463
- def _validate_tree_structure(self, tree: BaseTree.Tree, max_depth: int = 100) -> bool:
1464
-
1465
- try:
1466
- visited = set()
1467
-
1468
- def check_node(node, depth=0):
1469
- if depth > max_depth:
1470
- return False
1471
-
1472
- # Check for circular references
1473
- node_id = id(node)
1474
- if node_id in visited:
1475
- return False
1476
- visited.add(node_id)
1477
-
1478
- # Check children
1479
- for child in getattr(node, 'clades', []):
1480
- if not check_node(child, depth + 1):
1481
- return False
1482
-
1483
- return True
1484
-
1485
- return check_node(tree.root if hasattr(tree, 'root') else tree)
1486
- except Exception:
1487
- return False
1488
-
1489
- def _optimize_branch_lengths_ml_safe(self, tree: BaseTree.Tree, alignment: MultipleSeqAlignment) -> BaseTree.Tree:
1490
-
1491
- try:
1492
- print("🔧 Optimizing branch lengths with ML...")
1493
-
1494
- # Set recursion limit temporarily
1495
- old_limit = sys.getrecursionlimit()
1496
- sys.setrecursionlimit(1000)
1497
-
1498
- try:
1499
- # Convert alignment to matrix
1500
- seq_matrix = self._alignment_to_matrix(alignment)
1501
-
1502
- if seq_matrix.size == 0:
1503
- print("Warning: Empty sequence matrix, skipping optimization")
1504
- return tree
1505
-
1506
- # Get all internal and external nodes with depth tracking
1507
- all_clades = self._get_clades_safe(tree)
1508
-
1509
- # Simple branch length optimization
1510
- for clade in all_clades:
1511
- if hasattr(clade, 'branch_length') and clade.branch_length is not None:
1512
- try:
1513
- # Calculate optimal branch length based on likelihood
1514
- optimal_length = self._calculate_optimal_branch_length_safe(clade, seq_matrix)
1515
- clade.branch_length = max(optimal_length, 0.001) # Minimum branch length
1516
- except Exception as e:
1517
- print(f"Warning: Failed to optimize branch for clade: {e}")
1518
- # Keep original branch length
1519
- pass
1520
-
1521
- print("✓ Branch lengths optimized")
1522
-
1523
- finally:
1524
- # Restore original recursion limit
1525
- sys.setrecursionlimit(old_limit)
1526
-
1527
- return tree
1528
-
1529
- except Exception as e:
1530
- print(f"Warning: Branch length optimization failed: {e}")
1531
- return tree
1532
-
1533
- def _get_clades_safe(self, tree: BaseTree.Tree, max_depth: int = 50) -> List:
1534
-
1535
- clades = []
1536
- visited = set()
1537
-
1538
- def traverse_node(node, depth=0):
1539
- if depth > max_depth or id(node) in visited:
1540
- return
1541
-
1542
- visited.add(id(node))
1543
- clades.append(node)
1544
-
1545
- # Traverse children safely
1546
- try:
1547
- children = getattr(node, 'clades', [])
1548
- for child in children:
1549
- traverse_node(child, depth + 1)
1550
- except Exception:
1551
- pass # Skip problematic nodes
1552
-
1553
- try:
1554
- root = tree.root if hasattr(tree, 'root') else tree
1555
- traverse_node(root)
1556
- except Exception as e:
1557
- print(f"Warning: Tree traversal error: {e}")
1558
-
1559
- return clades
1560
-
1561
- def _calculate_optimal_branch_length_safe(self, clade, seq_matrix: np.ndarray) -> float:
1562
-
1563
- try:
1564
- # Simplified ML branch length estimation
1565
- if not hasattr(clade, 'branch_length') or clade.branch_length is None:
1566
- return 0.1 # Default branch length
1567
-
1568
- current_length = float(clade.branch_length)
1569
-
1570
- # Validate current length
1571
- if np.isnan(current_length) or np.isinf(current_length) or current_length <= 0:
1572
- return 0.1
1573
-
1574
- # Simple optimization based on sequence characteristics
1575
- if hasattr(clade, 'name') and clade.name:
1576
- # For terminal nodes
1577
- return min(max(current_length * 0.9, 0.001), 1.0)
1578
- else:
1579
- # For internal nodes
1580
- return min(max(current_length * 1.1, 0.001), 1.0)
1581
-
1582
- except Exception:
1583
- return 0.1 # Safe default
1584
-
1585
- def calculate_ml_likelihood_safe(self, tree: BaseTree.Tree, alignment: MultipleSeqAlignment) -> float:
1586
-
1587
- try:
1588
- print("📈 Calculating tree likelihood...")
1589
-
1590
- seq_matrix = self._alignment_to_matrix(alignment)
1591
-
1592
- if seq_matrix.size == 0:
1593
- return -np.inf
1594
-
1595
- # Simplified likelihood calculation using Jukes-Cantor model
1596
- total_log_likelihood = 0.0
1597
-
1598
- # For each site in the alignment (sample subset to avoid memory issues)
1599
- n_sites = min(seq_matrix.shape[1], 1000) # Limit sites for performance
1600
-
1601
- for site in range(0, n_sites, max(1, n_sites // 100)): # Sample sites
1602
- try:
1603
- site_pattern = seq_matrix[:, site]
1604
-
1605
- # Skip sites with gaps or N's
1606
- valid_positions = site_pattern < 4
1607
- if np.sum(valid_positions) < 2:
1608
- continue
1609
-
1610
- # Calculate likelihood for this site pattern
1611
- site_likelihood = self._calculate_site_likelihood_safe(tree, site_pattern)
1612
-
1613
- if site_likelihood > 0:
1614
- total_log_likelihood += np.log(site_likelihood)
1615
-
1616
- except Exception as e:
1617
- print(f"Warning: Error processing site {site}: {e}")
1618
- continue
1619
-
1620
- print(f"✓ Tree likelihood calculated: {total_log_likelihood:.2f}")
1621
- return total_log_likelihood
1622
-
1623
- except Exception as e:
1624
- print(f"Error calculating likelihood: {e}")
1625
- return -np.inf
1626
-
1627
- def _calculate_site_likelihood_safe(self, tree: BaseTree.Tree, site_pattern: np.ndarray) -> float:
1628
-
1629
- try:
1630
- # Count nucleotide frequencies at this site
1631
- valid_nucs = site_pattern[site_pattern < 4]
1632
-
1633
- if len(valid_nucs) == 0:
1634
- return 1.0
1635
-
1636
- # Simple likelihood based on nucleotide diversity
1637
- unique_nucs = len(np.unique(valid_nucs))
1638
- total_nucs = len(valid_nucs)
1639
-
1640
- # Higher diversity = lower likelihood of simple evolution
1641
- diversity_factor = unique_nucs / 4.0 # Normalize by 4 nucleotides
1642
-
1643
- # Simple likelihood model
1644
- likelihood = np.exp(-diversity_factor * total_nucs * 0.1)
1645
-
1646
- return max(likelihood, 1e-10) # Avoid zero likelihood
1647
-
1648
- except Exception:
1649
- return 1e-10 # Safe fallback
1650
-
1651
- def perform_ml_analysis_safe(self, matched_ids: List[str]) -> Dict:
1652
-
1653
- try:
1654
- print("\n🧬 PERFORMING MAXIMUM LIKELIHOOD ANALYSIS")
1655
- print("="*50)
1656
-
1657
- # Include query sequence in analysis
1658
- all_sequences = [self.query_id] + [seq_id for seq_id in matched_ids if seq_id != self.query_id]
1659
-
1660
- # Limit number of sequences to prevent memory issues
1661
- if len(all_sequences) > 20:
1662
- print(f"Warning: Limiting analysis to 20 sequences (had {len(all_sequences)})")
1663
- all_sequences = all_sequences[:20]
1664
-
1665
- if len(all_sequences) < 3:
1666
- print("❌ Need at least 3 sequences for ML analysis")
1667
- return {}
1668
-
1669
- # Step 1: Create multiple sequence alignment
1670
- alignment = self.create_sequence_alignment(all_sequences)
1671
- if not alignment:
1672
- return {}
1673
-
1674
- # Step 2: Calculate ML distances
1675
- distance_matrix = self.calculate_ml_distances(alignment)
1676
- if distance_matrix.size == 0:
1677
- return {}
1678
-
1679
- # Step 3: Construct ML tree
1680
- ml_tree = self.construct_ml_tree(alignment)
1681
- if not ml_tree:
1682
- return {}
1683
-
1684
- # Step 4: Calculate tree likelihood (safely)
1685
- log_likelihood = self.calculate_ml_likelihood_safe(ml_tree, alignment)
1686
-
1687
- # Step 5: Prepare results
1688
- ml_results = {
1689
- 'tree': ml_tree,
1690
- 'alignment': alignment,
1691
- 'distance_matrix': distance_matrix,
1692
- 'log_likelihood': log_likelihood,
1693
- 'sequence_count': len(all_sequences),
1694
- 'alignment_length': len(alignment[0]) if alignment else 0
1695
- }
1696
-
1697
- print(f"✅ ML analysis completed successfully")
1698
- print(f" Sequences analyzed: {len(all_sequences)}")
1699
- print(f" Alignment length: {ml_results['alignment_length']}")
1700
- print(f" Log-likelihood: {log_likelihood:.2f}")
1701
-
1702
- return ml_results
1703
-
1704
- except Exception as e:
1705
- print(f"❌ ML analysis failed: {e}")
1706
- import traceback
1707
- traceback.print_exc()
1708
- return {}
1709
-
1710
def build_tree_structure_with_ml_safe(self, matched_ids: List[str]) -> Dict:
    """Build the hierarchical tree structure, annotated with ML results.

    Runs the safe ML analysis first, builds the standard tree structure,
    then attaches an 'ml_analysis' summary dict to it. On any failure the
    method degrades gracefully: first to the plain structure, and finally
    to a dict containing only an 'error' key.

    Args:
        matched_ids: Sequence identifiers to include in the tree.

    Returns:
        The tree-structure dict (always contains 'ml_analysis' on the
        normal path), or an error dict if both constructions fail.
    """
    try:
        print("🌳 Building ML-enhanced tree structure...")

        ml_info = self.perform_ml_analysis_safe(matched_ids)
        structure = self.build_tree_structure(matched_ids)

        ml_succeeded = bool(ml_info) and 'tree' in ml_info
        if ml_succeeded:
            structure['ml_analysis'] = {
                'log_likelihood': ml_info['log_likelihood'],
                'sequence_count': ml_info['sequence_count'],
                'alignment_length': ml_info['alignment_length'],
                'ml_tree_available': True,
            }
            # Keep the ML tree/alignment on the instance for downstream use.
            self.ml_tree = ml_info['tree']
            self.ml_alignment = ml_info.get('alignment')
            print("✓ Tree structure enhanced with ML analysis")
        else:
            structure['ml_analysis'] = {
                'ml_tree_available': False,
                'error': 'ML analysis failed',
            }
            print("⚠️ ML analysis failed, using standard tree structure")

        return structure

    except Exception as e:
        print(f"Error building ML-enhanced tree structure: {e}")
        # Last-ditch fallback: attempt the standard construction on its own.
        try:
            return self.build_tree_structure(matched_ids)
        except Exception as e2:
            print(f"Fallback also failed: {e2}")
            return {'error': 'Both ML and standard tree construction failed'}
1752
-
1753
-
1754
def _print_tree_topology(self, tree, max_depth=3, current_depth=0, prefix=""):
    """Print a compact, depth-limited sketch of a phylogenetic tree.

    Best-effort console helper: shows at most the first 5 clades returned by
    ``tree.find_clades()`` and up to 2 "children" per internal clade. Any
    error is caught and reported rather than raised.

    Args:
        tree: A tree object exposing ``find_clades()`` whose clades have
            ``name``, ``branch_length`` and ``is_terminal()``
            (e.g. a Bio.Phylo tree).
        max_depth: Maximum depth to descend when sketching children.
        current_depth: Depth of the current call (internal bookkeeping).
        prefix: String prepended to each printed line for indentation.
    """
    if current_depth > max_depth:
        return

    try:
        # Get all clades at current level
        clades = list(tree.find_clades())

        for i, clade in enumerate(clades[:5]):  # Limit to first 5 for readability
            branch_info = ""
            if clade.branch_length is not None:
                branch_info = f" (len: {clade.branch_length:.4f})"

            # Same line format for both kinds of node; only the fallback
            # name differs (original code duplicated the print call).
            if clade.is_terminal():
                node_name = clade.name or "Terminal"
            else:
                node_name = clade.name or f"Internal_{i}"
            print(f"   {prefix}├── {node_name}{branch_info}")

            if current_depth < max_depth - 1 and not clade.is_terminal():
                # Show children (simplified). NOTE(review): find_clades()
                # yields the clade itself first and then ALL descendants in
                # preorder, so these are the first two descendants, not
                # necessarily direct children — kept as-is deliberately.
                children = list(clade.find_clades())
                if len(children) > 1:
                    for j, child in enumerate(children[1:3]):  # Show max 2 children
                        child_name = child.name or f"Node_{j}"
                        # Explicit None check: a zero-length branch (0.0) is
                        # valid and must still be displayed — truthiness
                        # testing would silently drop it.
                        child_branch = f" (len: {child.branch_length:.4f})" if child.branch_length is not None else ""
                        print(f"   {prefix}│   ├── {child_name}{child_branch}")

    except Exception as e:
        print(f"   Error displaying topology: {e}")
1786
-
1787
-
1788
-
1789
def main():
    """Interactive entry point for the phylogenetic tree analyzer.

    Workflow: load the (hard-coded) dataset, train the AI similarity model,
    prompt for a query sequence and a similarity threshold, then build the
    tree structure and write an interactive Plotly HTML visualization.

    Exits with status 1 on Ctrl+C or any unexpected error.
    """
    print("\n" + "="*70)
    print("🧬 PHYLOGENETIC TREE ANALYZER - ADVANCED ML-BASED ANALYSIS")
    print("="*70)
    print("Version 2.0 | AI-Enhanced Similarity Matching")
    print("Interactive Visualization with Variable Line Lengths")
    print("="*70)

    # Initialize the analyzer
    analyzer = PhylogeneticTreeAnalyzer()

    try:
        # Step 1: Load data.
        # BUGFIX: the file name is a fixed constant (Colab export), so the
        # original `while True: ... continue` retry loop could never succeed
        # on a second pass and spun forever when the file was missing or
        # malformed. Fail fast instead.
        data_file = "f cleaned.csv"
        if not Path(data_file).exists():
            print(f"❌ File not found: {data_file}")
            return
        if not analyzer.load_data(data_file):
            print("❌ Failed to load data. Please check file format.")
            return

        # Step 2: Train AI model automatically
        print("\n⏳ Training AI model... This may take a few moments.", flush=True)
        start_time = time.time()
        if analyzer.train_ai_model():
            elapsed = time.time() - start_time
            print(f"✅ AI model training completed in {elapsed:.1f} seconds", flush=True)
        else:
            print("⚠️ AI model training failed, continuing with basic analysis", flush=True)

        # Step 3: Get query sequence (loop until valid input or user gives up)
        while True:
            print("\n🔍 QUERY SEQUENCE INPUT:")
            print("   You can provide:")
            print("   1. Accession Number (e.g., 'MH087032') - from your dataset")
            print("   2. ANY F-gene nucleotide sequence (A, T, G, C)")
            print("   3. Novel sequences will be compared against your dataset")
            print("   Note: Minimum sequence length is 10 nucleotides")

            query_input = input("\nEnter query sequence or ID: ").strip()
            if not query_input:
                print("❌ Please provide a query sequence or ID.")
                continue

            if analyzer.find_query_sequence(query_input):
                break
            retry = input("❌ Invalid input. Try again? (y/n): ").strip().lower()
            if retry != 'y':
                print("👋 Analysis cancelled.")
                return

        # Step 4: Set similarity percentage (loop until a valid number)
        while True:
            try:
                print(f"\n📊 SIMILARITY THRESHOLD:")
                print(f"   - Higher values (90-99%): Find very similar sequences")
                print(f"   - Lower values (70-89%): Find more distantly related sequences")

                similarity_input = input(f"Enter target similarity percentage (1-99) [85]: ").strip()
                if not similarity_input:
                    target_percentage = 85.0  # Lowered default for novel sequences
                else:
                    target_percentage = float(similarity_input)

                if not (1 <= target_percentage <= 99):
                    print("❌ Please enter a percentage between 1 and 99.")
                    continue

                analyzer.matching_percentage = target_percentage
                break

            except ValueError:
                print("❌ Please enter a valid number.")
                continue

        # Step 5: Find similar sequences
        print(f"\n⏳ Analyzing sequences for {target_percentage}% similarity...")
        start_time = time.time()

        matched_ids, actual_percentage = analyzer.find_similar_sequences(target_percentage)

        if not matched_ids:
            print(f"❌ No similar sequences found at {target_percentage}% similarity.")
            print("💡 Try lowering the similarity percentage (e.g., 70-80%) to find more distant matches.")
            return

        analyzer.matched_sequences = matched_ids
        analyzer.actual_percentage = actual_percentage

        elapsed = time.time() - start_time
        print(f"✅ Similarity analysis completed in {elapsed:.1f} seconds")

        # Step 6: Build tree structure
        print("\n⏳ Building phylogenetic tree structure...")
        start_time = time.time()

        tree_structure = analyzer.build_tree_structure_with_ml_safe(matched_ids)
        if not tree_structure:
            print("❌ Failed to build tree structure.")
            return

        elapsed = time.time() - start_time
        print(f"✅ Tree structure built in {elapsed:.1f} seconds")

        # Step 7: Create visualization and save HTML
        print("\n⏳ Creating interactive visualization...")
        start_time = time.time()

        fig = analyzer.create_interactive_tree(matched_ids, actual_percentage)
        if not fig:
            print("❌ Visualization creation failed.")
            return

        elapsed = time.time() - start_time
        print(f"✅ Visualization created in {elapsed:.1f} seconds")

        # Save the interactive HTML file
        html_filename = "phylogenetic_tree_interactive.html"
        fig.write_html(html_filename)
        print(f"📄 Interactive HTML saved: {html_filename}")

        print(f"\n🎉 Analysis completed successfully!")
        print(f"   Query ID: {analyzer.query_id}")
        print(f"   Query sequence length: {len(analyzer.query_sequence)} nucleotides")
        print(f"   Similar sequences found: {len(matched_ids)}")
        print(f"   Actual similarity percentage: {actual_percentage:.1f}%")
        print(f"   HTML file generated: {html_filename}")

    except KeyboardInterrupt:
        print(f"\n\n⚠️ Analysis interrupted by user.")
        sys.exit(1)
    except Exception as e:
        print(f"\n❌ An error occurred during analysis: {e}")
        print(f"Please check your input data and try again.")
        sys.exit(1)
1933
-
1934
-
1935
def command_line_interface():
    """Non-interactive CLI wrapper around PhylogeneticTreeAnalyzer.

    Parses arguments, loads the dataset, optionally trains the AI model,
    then processes one query (``-q``) or a comma-separated batch
    (``--batch``), writing one interactive HTML file per successful query
    into ``--output-dir``.

    Exits with status 1 on invalid arguments or a failed data load.
    """
    parser = argparse.ArgumentParser(
        description="Advanced Phylogenetic Tree Analyzer with AI-enhanced similarity matching",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        # The stray "# " prefixes on these example lines were an artifact of
        # the notebook-to-script conversion (escaping of %-prefixed lines);
        # restored to plain usage examples.
        epilog="""
Examples:
  %(prog)s -d data.csv -q MH087032 -s 95
  %(prog)s -d data.csv -q MH087032 -s 90 --no-ai --batch query1,query2,query3
        """
    )

    parser.add_argument('-d', '--data', required=True,
                        help='Path to CSV data file')
    parser.add_argument('-q', '--query', required=True,
                        help='Query sequence ID or nucleotide sequence')
    parser.add_argument('-s', '--similarity', type=float, default=95.0,
                        help='Target similarity percentage (70-99, default: 95)')
    parser.add_argument('--no-ai', action='store_true',
                        help='Skip AI model training')
    parser.add_argument('--batch',
                        help='Comma-separated list of query IDs for batch processing')
    parser.add_argument('--output-dir', default='.',
                        help='Output directory for results')
    parser.add_argument('--save-json', action='store_true',
                        help='Save detailed results to JSON')

    args = parser.parse_args()

    # Validate arguments
    if not (70 <= args.similarity <= 99):
        print("❌ Similarity percentage must be between 70 and 99.")
        sys.exit(1)

    if not Path(args.data).exists():
        print(f"❌ Data file not found: {args.data}")
        sys.exit(1)

    # BUGFIX: --output-dir was parsed but never used; honor it now. The
    # default '.' keeps the previous behavior (files land in the CWD).
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    # NOTE(review): --save-json is accepted but not implemented in this
    # function — TODO wire it to a JSON export of the analysis results.

    # Initialize analyzer
    analyzer = PhylogeneticTreeAnalyzer()

    # Load data
    if not analyzer.load_data(args.data):
        print("❌ Failed to load data.")
        sys.exit(1)

    # Train AI model (unless disabled)
    if not args.no_ai:
        print("\n⏳ Training AI model... This may take a few moments.", flush=True)
        start_time = time.time()
        if analyzer.train_ai_model():
            elapsed = time.time() - start_time
            print(f"✅ AI model training completed in {elapsed:.1f} seconds", flush=True)
        else:
            print("⚠️ AI model training failed, continuing with basic analysis", flush=True)

    # Process queries (--batch takes precedence over the single -q query)
    queries = args.batch.split(',') if args.batch else [args.query]

    for query in queries:
        query = query.strip()
        print(f"\n🔍 Processing: {query}")

        if not analyzer.find_query_sequence(query):
            print(f"❌ Query not found: {query}")
            continue

        matched_ids, actual_percentage = analyzer.find_similar_sequences(args.similarity)

        if not matched_ids:
            print(f"❌ No similar sequences found for {query}")
            continue

        analyzer.build_tree_structure_with_ml_safe(matched_ids)
        fig = analyzer.create_interactive_tree(matched_ids, actual_percentage)

        if fig:
            # Save the interactive HTML file; slashes in IDs would otherwise
            # be interpreted as path separators.
            html_filename = output_dir / f"phylogenetic_tree_{query.replace('/', '_')}_interactive.html"
            fig.write_html(str(html_filename))
            print(f"📄 Interactive HTML saved: {html_filename}")

        print(f"✅ Analysis completed for {query}")
2016
-
2017
if __name__ == "__main__":
    # Script entry point: run the interactive pipeline and translate the two
    # expected ways it can stop into conventional exit codes.
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl+C at a prompt is a deliberate, clean exit.
        print(f"\n\n👋 Goodbye!")
        raise SystemExit(0)
    except Exception as e:
        # Anything else is a genuine failure.
        print(f"\n❌ Unexpected error: {e}")
        raise SystemExit(1)
    #KR815908