diff --git "a/analyzer.py" "b/analyzer.py"
--- "a/analyzer.py"
+++ "b/analyzer.py"
@@ -1,39 +1,31 @@
import pandas as pd
import numpy as np
import plotly.graph_objects as go
-import plotly.offline as pyo
-from plotly.subplots import make_subplots
-from Bio import Phylo, SeqIO, AlignIO
-from Bio.Phylo.TreeConstruction import DistanceCalculator, DistanceTreeConstructor
+from Bio import SeqIO, AlignIO
+from Bio.Phylo.TreeConstruction import DistanceCalculator, DistanceTreeConstructor, DistanceMatrix
+from Bio.Phylo.BaseTree import Tree
from Bio.Align import MultipleSeqAlignment
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
-from sklearn.feature_extraction.text import CountVectorizer
-from sklearn.metrics.pairwise import cosine_similarity
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import warnings
import os
import sys
-from typing import Dict, List, Tuple, Optional, Any
-import json
import re
-from scipy.optimize import minimize
-from scipy.spatial.distance import pdist, squareform
-from Bio.Phylo import BaseTree
-import itertools
-from collections import defaultdict, deque
-import argparse
import time
from pathlib import Path
+from typing import Dict, List, Tuple, Optional
+import itertools
+import argparse
warnings.filterwarnings('ignore')
class PhylogeneticTreeAnalyzer:
+ """Analyzes phylogenetic relationships using ML-based sequence similarity and tree construction."""
def __init__(self):
-
self.data = None
self.query_sequence = None
self.query_id = None
@@ -42,976 +34,649 @@ class PhylogeneticTreeAnalyzer:
self.matched_sequences = []
self.tree_structure = {}
self.similarity_scores = {}
- self.ai_model = None
- self.label_encoder = LabelEncoder()
- # ML-specific attributes
+ self.ai_model = None # ML model for sequence classification
+ self.genotype_model = None # Model for genotype prediction
+ self.label_encoder = LabelEncoder() # Encoder for ML labels
+ self.genotype_label_encoder = LabelEncoder() # Encoder for genotype labels
self.ml_tree = None
self.ml_alignment = None
self.ml_results = {}
- self.horizontal_line_tracker = [] # Track horizontal lines with verticals
- self.query_ml_group = None # Track which ML group contains the query
- self.base_horizontal_length = 1.2 # Base length for horizontal lines
-
- def load_data(self, data_file: str):
-
+ self.horizontal_line_tracker = []
+ self.query_ml_group = None
+ self.base_horizontal_length = 1.2
+ self.ml_model_accuracy = None # Accuracy of ML model
+ self.genotype_model_accuracy = None # Accuracy of genotype model
+
+ # --- Data Loading ---
+ def load_data(self, data_file: str) -> bool:
+ """Loads sequence data from a CSV file."""
try:
self.data = pd.read_csv(data_file)
- # required_columns = ['Accession Number', 'ML', 'Genotype', 'Host',
- # 'Country', 'Isolate', 'Year', 'F-gene']
-
- # missing_columns = [col for col in self.data.columns if col not in required_columns] # Corrected check for missing columns
- # if missing_columns:
- # print(f"Error: Missing required columns: {missing_columns}")
- # return False
-
- print(f"✓ Data loaded successfully: {len(self.data)} sequences")
- print(f"✓ ML Groups found: {self.data['ML'].nunique()}")
- print(f"✓ Genotypes found: {self.data['Genotype'].nunique()}")
+ print(f"✓ Data loaded: {len(self.data)} sequences, "
+ f"{self.data['ML'].nunique()} ML groups, "
+ f"{self.data['Genotype'].nunique()} genotypes")
return True
-
except Exception as e:
print(f"Error loading data: {e}")
return False
-
- def calculate_f_gene_similarity(self, seq1: str, seq2: str) -> float:
-
- try:
- # Handle empty or None sequences
- if not seq1 or not seq2:
- return 0.0
-
- # Convert to uppercase and remove non-nucleotide characters
- seq1 = re.sub(r'[^ATGC]', '', str(seq1).upper())
- seq2 = re.sub(r'[^ATGC]', '', str(seq2).upper())
-
- if len(seq1) == 0 or len(seq2) == 0:
- return 0.0
-
- # Use k-mer analysis for similarity calculation
- k = 5 # 5-mer analysis
- kmers1 = set([seq1[i:i+k] for i in range(len(seq1)-k+1) if len(seq1[i:i+k]) == k])
- kmers2 = set([seq2[i:i+k] for i in range(len(seq2)-k+1) if len(seq2[i:i+k]) == k])
-
- if len(kmers1) == 0 and len(kmers2) == 0:
- return 100.0
- elif len(kmers1) == 0 or len(kmers2) == 0:
- return 0.0
-
- # Calculate Jaccard similarity
- intersection = len(kmers1.intersection(kmers2))
- union = len(kmers1.union(kmers2))
- similarity = (intersection / union) * 100 if union > 0 else 0.0
-
- return round(similarity, 2)
-
- except Exception as e:
- print(f"Error calculating similarity: {e}")
- return 0.0
-
- def train_ai_model(self):
-
+ # --- Model Training ---
+ def train_ai_model(self) -> bool:
+ """Trains RandomForest models for ML group and genotype prediction."""
try:
-
- # Skip training if insufficient data
- if len(self.data) < 10: # Require minimum 10 samples
- print("⚠️ Insufficient data to train AI model (min 10 samples required)", flush=True)
+ if len(self.data) < 10:
+ print("⚠️ Insufficient data for training (minimum 10 samples)")
return False
- print("🤖 Training AI model for sequence analysis...", flush=True)
-
- # Prepare features from F-gene sequences
+ print("🤖 Training AI models...")
f_gene_sequences = self.data['F-gene'].fillna('').astype(str)
-
- # Create k-mer features (3-mers to 6-mers)
features = []
for seq in f_gene_sequences:
seq_clean = re.sub(r'[^ATGC]', '', seq.upper())
if len(seq_clean) < 3:
- features.append([0] * 100) # Placeholder for short sequences
+ features.append([0] * 100)
continue
-
feature_vector = []
- # 3-mers
kmers_3 = [seq_clean[i:i+3] for i in range(len(seq_clean)-2)]
kmer_counts_3 = {kmer: kmers_3.count(kmer) for kmer in set(kmers_3)}
-
- # 4-mers
kmers_4 = [seq_clean[i:i+4] for i in range(len(seq_clean)-3)]
kmer_counts_4 = {kmer: kmers_4.count(kmer) for kmer in set(kmers_4)}
-
- # Create feature vector (top 50 3-mers + top 50 4-mers)
- all_3mers = [''.join(p) for p in __import__('itertools').product('ATGC', repeat=3)]
- all_4mers = [''.join(p) for p in __import__('itertools').product('ATGC', repeat=4)]
-
+ all_3mers = [''.join(p) for p in itertools.product('ATGC', repeat=3)]
+ all_4mers = [''.join(p) for p in itertools.product('ATGC', repeat=4)]
feature_vector.extend([kmer_counts_3.get(kmer, 0) for kmer in all_3mers[:50]])
feature_vector.extend([kmer_counts_4.get(kmer, 0) for kmer in all_4mers[:50]])
-
features.append(feature_vector)
- # Prepare target labels (ML groups)
- targets = self.label_encoder.fit_transform(self.data['ML'].fillna('Unknown'))
-
- # Skip if only 1 class
- if len(np.unique(targets)) < 2:
- print("⚠️ Need at least 2 distinct classes for training", flush=True)
- return False
-
- # Train Random Forest model
X = np.array(features)
- y = targets
-
- X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
+ # Train ML model
+ ml_targets = self.label_encoder.fit_transform(self.data['ML'].fillna('Unknown'))
+ if len(np.unique(ml_targets)) < 2:
+ print("⚠️ Need at least 2 ML classes for training")
+ return False
+ X_train, X_test, y_train, y_test = train_test_split(X, ml_targets, test_size=0.2, random_state=42)
self.ai_model = RandomForestClassifier(n_estimators=100, random_state=42)
self.ai_model.fit(X_train, y_train)
-
- # Calculate accuracy
- accuracy = self.ai_model.score(X_test, y_test)
- print(f"✓ AI model trained successfully with accuracy: {accuracy:.2%}", flush=True)
+ self.ml_model_accuracy = self.ai_model.score(X_test, y_test)
+ print(f"✓ ML model trained with accuracy: {self.ml_model_accuracy:.2%}")
+
+ # Train genotype model
+ genotype_targets = self.genotype_label_encoder.fit_transform(self.data['Genotype'].fillna('Unknown'))
+ if len(np.unique(genotype_targets)) >= 2:
+ X_train, X_test, y_train, y_test = train_test_split(X, genotype_targets, test_size=0.2, random_state=42)
+ self.genotype_model = RandomForestClassifier(n_estimators=100, random_state=42)
+ self.genotype_model.fit(X_train, y_train)
+ self.genotype_model_accuracy = self.genotype_model.score(X_test, y_test)
+ print(f"✓ Genotype model trained with accuracy: {self.genotype_model_accuracy:.2%}")
return True
-
except Exception as e:
- print(f"🚨 CRITICAL training error: {e}", flush=True)
- import traceback
- traceback.print_exc()
+ print(f"Error training models: {e}")
return False
+    def predict_ml_group(self, sequence: str) -> str:
+        """Predict the ML group of ``sequence``; "Unknown" if no model, <3 clean bases, or on error."""
+        try:
+            # `is None`, not truthiness: scikit-learn ensembles define __len__ over
+            # `estimators_`, so `not model` raises on a created-but-unfitted model.
+            if self.ai_model is None:
+                return "Unknown"
+            seq_clean = re.sub(r'[^ATGC]', '', str(sequence).upper())  # str(): tolerate non-string input
+            if len(seq_clean) < 3:
+                return "Unknown"
+            # Single linear pass over overlapping 3-mers and 4-mers; the two
+            # lengths differ, so one dict can hold both without key clashes.
+            counts = {}
+            for k in (3, 4):
+                for i in range(len(seq_clean) - k + 1):
+                    counts[seq_clean[i:i + k]] = counts.get(seq_clean[i:i + k], 0) + 1
+            # Same layout as train_ai_model: first 50 3-mers, then first 50 4-mers.
+            feature_vector = [counts.get(''.join(p), 0) for k in (3, 4)
+                              for p in itertools.islice(itertools.product('ATGC', repeat=k), 50)]
+            return self.label_encoder.inverse_transform(self.ai_model.predict(np.array([feature_vector])))[0]
+        except Exception as e:
+            print(f"Error predicting ML group: {e}")
+            return "Unknown"
+
+    def predict_genotype(self, sequence: str) -> str:
+        """Predict the genotype of ``sequence``; "Unknown" if no model, <3 clean bases, or on error."""
+        try:
+            # `is None`, not truthiness: scikit-learn ensembles define __len__ over
+            # `estimators_`, so `not model` raises on a created-but-unfitted model.
+            if self.genotype_model is None:
+                return "Unknown"
+            seq_clean = re.sub(r'[^ATGC]', '', str(sequence).upper())  # str(): tolerate non-string input
+            if len(seq_clean) < 3:
+                return "Unknown"
+            # Single linear pass over overlapping 3-mers and 4-mers; the two
+            # lengths differ, so one dict can hold both without key clashes.
+            counts = {}
+            for k in (3, 4):
+                for i in range(len(seq_clean) - k + 1):
+                    counts[seq_clean[i:i + k]] = counts.get(seq_clean[i:i + k], 0) + 1
+            # Same layout as train_ai_model: first 50 3-mers, then first 50 4-mers.
+            feature_vector = [counts.get(''.join(p), 0) for k in (3, 4)
+                              for p in itertools.islice(itertools.product('ATGC', repeat=k), 50)]
+            return self.genotype_label_encoder.inverse_transform(self.genotype_model.predict(np.array([feature_vector])))[0]
+        except Exception as e:
+            print(f"Error predicting genotype: {e}")
+            return "Unknown"
+
+ # --- Sequence Processing ---
def find_query_sequence(self, query_input: str) -> bool:
- """
- Modified to match the sequence parameter against both [Accession Number] and [F-gene] columns in the dataset.
- """
+ """Identifies query sequence by accession number, F-gene, or as a novel sequence."""
try:
- # Check if input is an accession number from the dataset
+ query_input = query_input.strip()
if query_input in self.data['Accession Number'].values:
self.query_id = query_input
query_row = self.data[self.data['Accession Number'] == query_input].iloc[0]
self.query_sequence = query_row['F-gene']
- print(f"✓ Query sequence found by ID: {query_input}")
+ print(f"✓ Query found by accession: {query_input}, ML: {query_row['ML']}, Genotype: {query_row['Genotype']}")
return True
-
- # Check if input matches an F-gene sequence in the dataset
query_clean = re.sub(r'[^ATGC]', '', str(query_input).upper())
if query_clean in self.data['F-gene'].values:
query_row = self.data[self.data['F-gene'] == query_clean].iloc[0]
self.query_id = query_row['Accession Number']
self.query_sequence = query_clean
- print(f"✓ Query sequence found in F-gene column: {query_clean}")
+ print(f"✓ Query matched to accession: {self.query_id}, ML: {query_row['ML']}, Genotype: {query_row['Genotype']}")
return True
-
- # Accept any sequence with reasonable length (even short ones for testing)
- if len(query_clean) >= 10: # Minimum sequence length
- # For sequences not in dataset, create a unique identifier
+ if len(query_clean) >= 10:
self.query_id = f"QUERY_{hash(query_clean) % 100000:05d}"
self.query_sequence = query_clean
- print(f"✓ Novel query sequence accepted with ID: {self.query_id}")
- print(f" Sequence length: {len(query_clean)} nucleotides")
+ predicted_ml = self.predict_ml_group(query_clean)
+ predicted_genotype = self.predict_genotype(query_clean)
+ print(f"✓ Novel query accepted: {self.query_id}, Length: {len(query_clean)}, "
+ f"Predicted ML: {predicted_ml}, Predicted Genotype: {predicted_genotype}")
return True
-
- # If sequence is too short or invalid
- if len(query_clean) < 10:
- print(f"✗ Query sequence too short. Minimum length: 10 nucleotides (provided: {len(query_clean)})")
- else:
- print(f"✗ Invalid sequence format. Please provide nucleotides (A, T, G, C) or valid accession number")
-
+ print(f"✗ Invalid query: Too short (<10) or not found")
return False
-
except Exception as e:
- print(f"Error processing query sequence: {e}")
+ print(f"Error processing query: {e}")
return False
+    def calculate_f_gene_similarity(self, seq1: str, seq2: str) -> float:
+        """Return the 5-mer Jaccard similarity (0-100, 2 decimals); 0.0 on empty input or error."""
+        k = 5
+        try:
+            if not seq1 or not seq2:
+                return 0.0
+            # Compare only unambiguous bases: everything except A/T/G/C is dropped.
+            seq1 = re.sub(r'[^ATGC]', '', str(seq1).upper())
+            seq2 = re.sub(r'[^ATGC]', '', str(seq2).upper())
+            if not seq1 or not seq2:
+                return 0.0
+            kmers1 = {seq1[i:i + k] for i in range(len(seq1) - k + 1)}
+            kmers2 = {seq2[i:i + k] for i in range(len(seq2) - k + 1)}
+            if not kmers1 and not kmers2:
+                # Both shorter than k: only identical sequences count as a match
+                # (previously any two short sequences scored 100.0).
+                return 100.0 if seq1 == seq2 else 0.0
+            # An empty set on one side gives an empty intersection, i.e. 0.0.
+            return round(len(kmers1 & kmers2) / len(kmers1 | kmers2) * 100, 2)
+        except Exception as e:
+            print(f"Error calculating similarity: {e}")
+            return 0.0
+
def find_similar_sequences(self, target_percentage: float) -> Tuple[List[str], float]:
- """
- Modified to work with any query sequence, including novel ones not in the dataset.
- """
+ """Finds sequences similar to the query sequence."""
try:
- print(f"🔍 Finding sequences with {target_percentage}% similarity to query...")
+ print(f"🔍 Finding sequences with {target_percentage}% similarity...")
similarities = []
-
- # Calculate similarity between query and all sequences in dataset
for idx, row in self.data.iterrows():
- # Skip if this is the same sequence (only relevant for existing accession numbers)
- if hasattr(self, 'query_id') and row['Accession Number'] == self.query_id:
+ if row['Accession Number'] == self.query_id:
continue
-
- try:
- similarity = self.calculate_f_gene_similarity(self.query_sequence, row['F-gene'])
- similarities.append({
- 'id': row['Accession Number'],
- 'similarity': similarity,
- 'ml': row['ML'] if 'ML' in row else 'Unknown',
- 'genotype': row['Genotype'] if 'Genotype' in row else 'Unknown'
- })
- except Exception as seq_error:
- print(f"⚠ Skipping sequence {row['Accession Number']}: {seq_error}")
- continue
-
+ similarity = self.calculate_f_gene_similarity(self.query_sequence, row['F-gene'])
+ similarities.append({
+ 'id': row['Accession Number'],
+ 'similarity': similarity,
+ 'ml': row.get('ML', 'Unknown'),
+ 'genotype': row.get('Genotype', 'Unknown')
+ })
if not similarities:
- print("❌ No valid sequences found for comparison")
+ print("❌ No valid sequences for comparison")
return [], target_percentage
-
- # Sort by similarity (highest first)
similarities.sort(key=lambda x: x['similarity'], reverse=True)
-
- # Find sequences within target percentage range (±2%)
target_range = 2.0
- candidates = [s for s in similarities
- if abs(s['similarity'] - target_percentage) <= target_range]
-
+ candidates = [s for s in similarities if abs(s['similarity'] - target_percentage) <= target_range]
if not candidates:
- # If no exact matches, find sequences with closest similarity
- closest_sim = min(similarities, key=lambda x: abs(x['similarity'] - target_percentage))
- actual_percentage = closest_sim['similarity']
-
- # Get sequences within ±1% of the closest similarity
- candidates = [s for s in similarities
- if abs(s['similarity'] - actual_percentage) <= 1.0]
-
- print(f"⚠ No sequences found at exactly {target_percentage}%. Using closest: {actual_percentage:.1f}%")
+ closest = min(similarities, key=lambda x: abs(x['similarity'] - target_percentage))
+ actual_percentage = closest['similarity']
+ candidates = [s for s in similarities if abs(s['similarity'] - actual_percentage) <= 1.0]
+ print(f"⚠ No sequences at {target_percentage}%. Using closest: {actual_percentage:.1f}%")
else:
actual_percentage = target_percentage
-
- # Limit results to prevent overwhelming visualization (optional)
- max_results = 50 # Adjust as needed
+ max_results = 50
if len(candidates) > max_results:
candidates = candidates[:max_results]
- print(f"⚠ Limited results to top {max_results} matches for better visualization")
-
- # Store similarity scores for later use
- self.similarity_scores = {} # Reset similarity scores
- for candidate in candidates:
- self.similarity_scores[candidate['id']] = candidate['similarity']
-
+ print(f"⚠ Limited to top {max_results} matches")
+ self.similarity_scores = {c['id']: c['similarity'] for c in candidates}
matched_ids = [c['id'] for c in candidates]
-
- # Show some statistics
if similarities:
- max_sim = max(similarities, key=lambda x: x['similarity'])['similarity']
- min_sim = min(similarities, key=lambda x: x['similarity'])['similarity']
+ max_sim = max(s['similarity'] for s in similarities)
+ min_sim = min(s['similarity'] for s in similarities)
avg_sim = sum(s['similarity'] for s in similarities) / len(similarities)
-
- print(f"✓ Found {len(matched_ids)} sequences at ~{actual_percentage:.1f}% similarity")
- print(f" Similarity range in dataset: {min_sim:.1f}% - {max_sim:.1f}% (avg: {avg_sim:.1f}%)")
-
+ print(f"✓ Found {len(matched_ids)} sequences at ~{actual_percentage:.1f}% similarity, "
+ f"Range: {min_sim:.1f}% - {max_sim:.1f}% (avg: {avg_sim:.1f}%)")
return matched_ids, actual_percentage
-
except Exception as e:
print(f"Error finding similar sequences: {e}")
return [], target_percentage
-
+ # --- Tree Construction ---
def build_tree_structure(self, matched_ids: List[str]) -> Dict:
- try:
- print("🌳 Building normalized horizontal tree structure...")
-
- # Initialize tree structure
- tree_structure = {
- 'root': {
- 'name': 'Root',
- 'type': 'root',
- 'children': {},
- 'x': 0,
- 'y': 0,
- 'has_vertical_attachment': False,
- 'extension_level': 0
- }
- }
-
- # Group sequences by ML and Genotype
- ml_groups = {}
- for idx, row in self.data.iterrows():
- ml_group = row['ML']
- genotype = row['Genotype']
- seq_id = row['Accession Number']
-
- if ml_group not in ml_groups:
- ml_groups[ml_group] = {}
-
- if genotype not in ml_groups[ml_group]:
- ml_groups[ml_group][genotype] = []
-
- ml_groups[ml_group][genotype].append({
- 'id': seq_id,
- 'data': row.to_dict(),
- 'is_query': seq_id == self.query_id,
- 'is_matched': seq_id in matched_ids,
- 'similarity': self.similarity_scores.get(seq_id, 0.0)
- })
-
- # Normalize ML group names and organize
- normalized_ml_groups = self._normalize_ml_groups(ml_groups)
-
- # Build normalized ML level - horizontal layout with progressive extensions
- self._build_normalized_ml_nodes(tree_structure, normalized_ml_groups, matched_ids)
-
- self.tree_structure = tree_structure
- print(f"✓ Normalized horizontal tree structure built")
- return tree_structure
+        """Build the root -> ML group -> genotype hierarchy for the loaded dataset.
+
+        Every dataset row is grouped by its 'ML' and 'Genotype' values and flagged as
+        query and/or matched. A novel query (id starting with "QUERY_") is not in the
+        dataset, so it is added under its predicted ML group and genotype. The result
+        is stored in ``self.tree_structure`` and returned; on any error an empty dict
+        is returned instead.
+        """
+        try:
+            print("🌳 Building normalized tree structure...")
+            tree_structure = {
+                'root': {'name': 'Root', 'type': 'root', 'children': {}, 'x': 0, 'y': 0,
+                         'has_vertical_attachment': False, 'extension_level': 0}
+            }
+            matched = set(matched_ids)  # O(1) membership instead of a list scan per row
+            ml_groups = {}
+            for _, row in self.data.iterrows():
+                seq_id = row['Accession Number']
+                ml_groups.setdefault(row['ML'], {}).setdefault(row['Genotype'], []).append({
+                    'id': seq_id, 'data': row.to_dict(), 'is_query': seq_id == self.query_id,
+                    'is_matched': seq_id in matched, 'similarity': self.similarity_scores.get(seq_id, 0.0)
+                })
+            # query_id is None until a query has been resolved; guard so the tree can
+            # still be built instead of failing on None.startswith.
+            if isinstance(self.query_id, str) and self.query_id.startswith("QUERY_"):
+                predicted_ml = self.predict_ml_group(self.query_sequence)
+                predicted_genotype = self.predict_genotype(self.query_sequence)
+                ml_groups.setdefault(predicted_ml, {}).setdefault(predicted_genotype, []).append({
+                    'id': self.query_id, 'data': {
+                        'F-gene': self.query_sequence, 'ML': predicted_ml, 'Genotype': predicted_genotype,
+                        'Accession Number': self.query_id
+                    }, 'is_query': True, 'is_matched': False, 'similarity': 100.0
+                })
+            normalized_ml_groups = self._normalize_ml_groups(ml_groups)
+            self._build_normalized_ml_nodes(tree_structure, normalized_ml_groups, matched_ids)
+            self.tree_structure = tree_structure
+            print("✓ Tree structure built")
+            return tree_structure
+        except Exception as e:
+            print(f"Error building tree structure: {e}")
+            return {}
- except Exception as e:
- print(f"Error building tree structure: {e}")
- return {}
+    def build_tree_structure_with_ml_safe(self, matched_ids: List[str]) -> Dict:
+        """Build the standard tree and attach an 'ml_analysis' summary; fall back to the plain tree on error."""
+        try:
+            print("🌳 Building ML-enhanced tree structure...")
+            ml = self.perform_ml_analysis_safe(matched_ids)
+            tree = self.build_tree_structure(matched_ids)
+            if not ml or 'tree' not in ml:
+                tree['ml_analysis'] = {'ml_tree_available': False, 'error': 'ML analysis failed'}
+                print("⚠ ML analysis failed, using standard tree")
+                return tree
+            # Direct indexing: a missing key raises here and is handled by the fallback below.
+            summary = {key: ml[key]
+                       for key in ('log_likelihood', 'sequence_count', 'alignment_length')}
+            summary['ml_tree_available'] = True
+            tree['ml_analysis'] = summary
+            self.ml_tree = ml['tree']
+            self.ml_alignment = ml.get('alignment')
+            print("✓ Tree enhanced with ML analysis")
+            return tree
+        except Exception as exc:
+            print(f"Error building ML-enhanced tree: {exc}")
+        # Reached only after a failure above: retry with the plain tree builder.
+        try:
+            return self.build_tree_structure(matched_ids)
+        except Exception as fallback_exc:
+            print(f"Fallback failed: {fallback_exc}")
+            return {'error': 'Tree construction failed'}
def _normalize_ml_groups(self, ml_groups: Dict) -> Dict:
- """Normalize ML group names and organize hierarchically"""
+ """Normalizes ML group names for hierarchical organization."""
try:
normalized_groups = {}
-
for ml_name, genotypes in ml_groups.items():
- # Extract base ML name
- if ml_name.startswith('UNCL'):
- base_ml = 'UNCL'
- elif '.' in ml_name and any(char.isdigit() for char in ml_name):
- # For names like XII.1.2, XII.1, etc., extract the base (XII)
- base_ml = ml_name.split('.')[0]
- else:
- base_ml = ml_name
-
- # Initialize normalized group structure
+ base_ml = 'UNCL' if ml_name.startswith('UNCL') else ml_name.split('.')[0] if '.' in ml_name and any(c.isdigit() for c in ml_name) else ml_name
if base_ml not in normalized_groups:
- normalized_groups[base_ml] = {
- 'full_ml_groups': {},
- 'representative_sequences': [],
- 'has_special_sequences': False
- }
-
- # Check if this ML group has query or matched sequences
- has_special = any(
- any(seq['is_query'] or seq['is_matched'] for seq in sequences)
- for sequences in genotypes.values()
- )
-
+ normalized_groups[base_ml] = {'full_ml_groups': {}, 'representative_sequences': [], 'has_special_sequences': False}
+ has_special = any(any(seq['is_query'] or seq['is_matched'] for seq in seqs) for seqs in genotypes.values())
if has_special:
normalized_groups[base_ml]['has_special_sequences'] = True
normalized_groups[base_ml]['full_ml_groups'][ml_name] = genotypes
- else:
- # Add as representative (limit to 2 representatives)
- if len(normalized_groups[base_ml]['representative_sequences']) < 2:
- # Get 1-2 representative sequences from this ML group
- for genotype, sequences in list(genotypes.items())[:2]:
- if len(normalized_groups[base_ml]['representative_sequences']) < 2:
- normalized_groups[base_ml]['representative_sequences'].extend(sequences[:1])
-
+ elif len(normalized_groups[base_ml]['representative_sequences']) < 2:
+ for genotype, sequences in list(genotypes.items())[:2]:
+ if len(normalized_groups[base_ml]['representative_sequences']) < 2:
+ normalized_groups[base_ml]['representative_sequences'].extend(sequences[:1])
return normalized_groups
-
except Exception as e:
print(f"Error normalizing ML groups: {e}")
return {}
def _build_normalized_ml_nodes(self, tree_structure: Dict, normalized_ml_groups: Dict, matched_ids: List[str]):
- """Build normalized ML nodes with equal spacing and progressive horizontal extensions"""
+ """Builds normalized ML nodes with equal spacing."""
try:
- # Reset horizontal line tracker
self.horizontal_line_tracker = []
-
- # Identify which ML group contains the query
self._identify_query_ml_group(normalized_ml_groups)
-
- # Calculate equal spacing for all ML groups
ml_positions = self._calculate_dynamic_ml_positions(normalized_ml_groups)
-
- # Mark root as having vertical attachment if it has multiple children
- root_has_vertical = len(normalized_ml_groups) > 1
- tree_structure['root']['has_vertical_attachment'] = root_has_vertical
-
+ tree_structure['root']['has_vertical_attachment'] = len(normalized_ml_groups) > 1
for ml_idx, (base_ml, ml_data) in enumerate(normalized_ml_groups.items()):
y_pos = ml_positions[ml_idx]
-
- # Determine if this ML node will have vertical attachments
has_vertical = ml_data['has_special_sequences'] and len(ml_data['full_ml_groups']) > 1
-
- # Check if this ML group contains the query
- contains_query = (base_ml == self.query_ml_group)
-
- # Calculate horizontal line length based on connections and query presence
- horizontal_length = self._determine_horizontal_line_length(
- 'normalized_ml_group', has_vertical, contains_query
- )
-
+ contains_query = base_ml == self.query_ml_group
+ horizontal_length = self._determine_horizontal_line_length('normalized_ml_group', has_vertical, contains_query)
x_pos = horizontal_length
-
- # Create normalized ML node
tree_structure['root']['children'][base_ml] = {
- 'name': base_ml,
- 'type': 'normalized_ml_group',
- 'children': {},
- 'x': x_pos,
- 'y': y_pos,
- 'has_special_sequences': ml_data['has_special_sequences'],
- 'has_vertical_attachment': has_vertical,
- 'horizontal_line_length': horizontal_length,
- 'contains_query': contains_query
+ 'name': base_ml, 'type': 'normalized_ml_group', 'children': {}, 'x': x_pos, 'y': y_pos,
+ 'has_special_sequences': ml_data['has_special_sequences'], 'has_vertical_attachment': has_vertical,
+ 'horizontal_line_length': horizontal_length, 'contains_query': contains_query
}
-
if ml_data['has_special_sequences']:
- # Build full ML nodes for groups with special sequences
- self._build_full_ml_nodes(
- tree_structure['root']['children'][base_ml],
- ml_data['full_ml_groups'],
- y_pos,
- matched_ids,
- x_pos
- )
+ self._build_full_ml_nodes(tree_structure['root']['children'][base_ml], ml_data['full_ml_groups'],
+ y_pos, matched_ids, x_pos)
else:
- # Add representative sequences directly
- self._add_representative_sequences(
- tree_structure['root']['children'][base_ml],
- ml_data['representative_sequences'],
- y_pos,
- x_pos
- )
-
+ self._add_representative_sequences(tree_structure['root']['children'][base_ml],
+ ml_data['representative_sequences'], y_pos, x_pos)
except Exception as e:
print(f"Error building normalized ML nodes: {e}")
- def _calculate_dynamic_ml_positions(self, normalized_ml_groups: Dict) -> List[float]:
- """Calculate equal Y positions for all ML groups regardless of content"""
- try:
- ml_count = len(normalized_ml_groups)
- if ml_count == 0:
- return []
-
- if ml_count == 1:
- return [0.0]
-
- # Equal spacing between all ML nodes
- total_spacing = (ml_count - 1) * 2.0 # 2.0 units between each ML node
- start_y = -total_spacing / 2
-
- positions = []
- for i in range(ml_count):
- positions.append(start_y + i * 2.0)
-
- return positions
-
- except Exception as e:
- print(f"Error calculating dynamic positions: {e}")
- return list(range(len(normalized_ml_groups)))
-
def _build_full_ml_nodes(self, normalized_ml_node: Dict, full_ml_groups: Dict, base_y: float, matched_ids: List[str], parent_x: float):
- """Build full ML nodes with genotypes for groups containing special sequences"""
+ """Builds full ML nodes with genotypes."""
try:
- # Calculate equal positions for full ML groups
full_ml_positions = self._calculate_full_ml_positions(full_ml_groups, base_y)
-
for ml_idx, (full_ml_name, genotypes) in enumerate(full_ml_groups.items()):
y_pos = full_ml_positions[ml_idx]
-
- # Determine if this full ML node will have vertical attachments
- special_genotypes_count = sum(1 for genotype, sequences in genotypes.items()
- if any(seq['is_query'] or seq['is_matched'] for seq in sequences))
+ special_genotypes_count = sum(1 for g, seqs in genotypes.items() if any(s['is_query'] or s['is_matched'] for s in seqs))
has_vertical = special_genotypes_count > 1
-
- # Check if this full ML group contains the query
- contains_query = any(
- any(seq['is_query'] for seq in sequences)
- for sequences in genotypes.values()
- )
-
- # Calculate horizontal line length
- horizontal_length = self._determine_horizontal_line_length(
- 'full_ml_group', has_vertical, contains_query
- )
-
+ contains_query = any(any(seq['is_query'] for seq in seqs) for seqs in genotypes.values())
+ horizontal_length = self._determine_horizontal_line_length('full_ml_group', has_vertical, contains_query)
x_pos = parent_x + horizontal_length
-
- # Create full ML node
normalized_ml_node['children'][full_ml_name] = {
- 'name': full_ml_name,
- 'type': 'full_ml_group',
- 'children': {},
- 'x': x_pos,
- 'y': y_pos,
- 'sequences_count': sum(len(seqs) for seqs in genotypes.values()),
- 'has_vertical_attachment': has_vertical,
- 'horizontal_line_length': horizontal_length,
- 'contains_query': contains_query
+ 'name': full_ml_name, 'type': 'full_ml_group', 'children': {}, 'x': x_pos, 'y': y_pos,
+ 'sequences_count': sum(len(seqs) for seqs in genotypes.values()), 'has_vertical_attachment': has_vertical,
+ 'horizontal_line_length': horizontal_length, 'contains_query': contains_query
}
-
- # Build genotype nodes
- self._build_genotype_nodes(
- normalized_ml_node['children'][full_ml_name],
- genotypes,
- y_pos,
- matched_ids,
- x_pos
- )
-
+ self._build_genotype_nodes(normalized_ml_node['children'][full_ml_name], genotypes, y_pos, matched_ids, x_pos)
except Exception as e:
print(f"Error building full ML nodes: {e}")
- def _calculate_full_ml_positions(self, full_ml_groups: Dict, base_y: float) -> List[float]:
- """Calculate equal positions for full ML groups"""
- try:
- ml_count = len(full_ml_groups)
- if ml_count <= 1:
- return [base_y]
-
- # Equal spacing for full ML groups
- spacing = 1.5 # Fixed spacing between full ML groups
- start_y = base_y - (spacing * (ml_count - 1)) / 2
-
- positions = []
- for i in range(ml_count):
- positions.append(start_y + i * spacing)
-
- return positions
-
- except Exception as e:
- print(f"Error calculating full ML positions: {e}")
- return [base_y] * len(full_ml_groups)
-
def _build_genotype_nodes(self, full_ml_node: Dict, genotypes: Dict, base_y: float, matched_ids: List[str], parent_x: float):
- """Build genotype nodes with sequences - horizontal line length based on sequence count"""
+ """Builds genotype nodes with sequences."""
try:
- # Filter genotypes with special sequences
- special_genotypes = []
- for genotype, sequences in genotypes.items():
- if any(seq['is_query'] or seq['is_matched'] for seq in sequences):
- special_genotypes.append((genotype, sequences))
-
+ special_genotypes = [(g, seqs) for g, seqs in genotypes.items() if any(s['is_query'] or s['is_matched'] for s in seqs)]
if not special_genotypes:
return
-
- # Calculate equal genotype positions (vertical positioning remains equal)
genotype_positions = self._calculate_genotype_positions(special_genotypes, base_y)
-
- # Calculate sequence counts for each genotype to determine horizontal line lengths
- genotype_sequence_counts = []
- for genotype, sequences in special_genotypes:
- special_sequences = [seq for seq in sequences if seq['is_query'] or seq['is_matched']]
- genotype_sequence_counts.append((genotype, sequences, len(special_sequences)))
-
+ genotype_sequence_counts = [(g, seqs, len([s for s in seqs if s['is_query'] or s['is_matched']])) for g, seqs in special_genotypes]
for gt_idx, (genotype, sequences, sequence_count) in enumerate(genotype_sequence_counts):
y_pos = genotype_positions[gt_idx]
-
- # Determine if this genotype will have vertical attachments
- special_sequences = [seq for seq in sequences if seq['is_query'] or seq['is_matched']]
+ special_sequences = [s for s in sequences if s['is_query'] or s['is_matched']]
has_vertical = len(special_sequences) > 1
-
- # Check if this genotype contains the query
- contains_query = any(seq['is_query'] for seq in sequences)
-
- # Calculate horizontal line length based on sequence count
- horizontal_length = self._determine_genotype_horizontal_line_length(
- sequence_count, has_vertical, contains_query
- )
-
+ contains_query = any(s['is_query'] for s in sequences)
+ horizontal_length = self._determine_genotype_horizontal_line_length(sequence_count, has_vertical, contains_query)
x_pos = parent_x + horizontal_length
-
- # Create genotype node
full_ml_node['children'][genotype] = {
- 'name': genotype,
- 'type': 'genotype',
- 'children': {},
- 'x': x_pos,
- 'y': y_pos,
- 'sequences': sequences,
- 'has_vertical_attachment': has_vertical,
- 'horizontal_line_length': horizontal_length,
- 'contains_query': contains_query,
- 'sequence_count': sequence_count # Store for reference
+ 'name': genotype, 'type': 'genotype', 'children': {}, 'x': x_pos, 'y': y_pos,
+ 'sequences': sequences, 'has_vertical_attachment': has_vertical,
+ 'horizontal_line_length': horizontal_length, 'contains_query': contains_query,
+ 'sequence_count': sequence_count
}
-
- # Add sequences horizontally
- self._add_sequences_horizontal(
- full_ml_node['children'][genotype],
- sequences,
- y_pos,
- x_pos
- )
-
+ self._add_sequences_horizontal(full_ml_node['children'][genotype], sequences, y_pos, x_pos)
except Exception as e:
print(f"Error building genotype nodes: {e}")
- def _determine_genotype_horizontal_line_length(self, sequence_count: int, has_vertical: bool, contains_query: bool = False) -> float:
- """Determine horizontal line length for genotype nodes based on sequence count"""
- try:
- base_length = self.base_horizontal_length
-
- # Special case: Genotype containing query sequence gets additional length
- query_bonus = 0.5 if contains_query else 0.0
-
- # Calculate length based on sequence count
- # More sequences = longer horizontal line
- if sequence_count <= 1:
- # Single sequence
- length_multiplier = 1.0
- elif sequence_count <= 3:
- # 2-3 sequences
- length_multiplier = 1.6
- elif sequence_count <= 5:
- # 4-5 sequences
- length_multiplier = 2.3
- elif sequence_count <= 8:
- # 6-8 sequences
- length_multiplier = 6.0
- else:
- # More than 8 sequences
- length_multiplier = 6.0
-
- # Calculate final length
- calculated_length = base_length * length_multiplier + query_bonus
-
- return calculated_length
-
- except Exception as e:
- print(f"Error determining genotype horizontal line length: {e}")
- return self.base_horizontal_length
-
- def _calculate_genotype_positions(self, special_genotypes: List, base_y: float) -> List[float]:
- """Calculate equal positions for genotypes"""
- try:
- genotype_count = len(special_genotypes)
- if genotype_count <= 1:
- return [base_y]
-
- # Equal spacing for genotypes
- spacing = 1.0 # Fixed spacing between genotypes
- start_y = base_y - (spacing * (genotype_count - 1)) / 2
-
- positions = []
- for i in range(genotype_count):
- positions.append(start_y + i * spacing)
-
- return positions
-
- except Exception as e:
- print(f"Error calculating genotype positions: {e}")
- return [base_y] * len(special_genotypes)
-
def _add_representative_sequences(self, normalized_ml_node: Dict, representative_sequences: List[Dict], base_y: float, parent_x: float):
- """Add representative sequences directly to normalized ML node"""
+ """Adds representative sequences to normalized ML nodes."""
try:
if not representative_sequences:
return
-
- # Calculate horizontal line length for representative sequences
- # Representative sequences get a standard length (not similarity-based since they're not matched)
has_vertical = len(representative_sequences) > 1
horizontal_length = self._determine_horizontal_line_length('representative', has_vertical)
x_pos = parent_x + horizontal_length
-
if len(representative_sequences) == 1:
seq = representative_sequences[0]
normalized_ml_node['children'][f"{seq['id']}_rep"] = {
- 'name': f"{seq['id']} (Rep)",
- 'type': 'representative_sequence',
- 'data': seq,
- 'x': x_pos,
- 'y': base_y,
- 'has_vertical_attachment': False,
- 'horizontal_line_length': horizontal_length
+ 'name': f"{seq['id']} (Rep)", 'type': 'representative_sequence', 'data': seq,
+ 'x': x_pos, 'y': base_y, 'has_vertical_attachment': False, 'horizontal_line_length': horizontal_length
}
else:
- # Equal spacing for multiple representative sequences
positions = self._calculate_sequence_positions(representative_sequences, base_y)
-
for idx, seq in enumerate(representative_sequences):
normalized_ml_node['children'][f"{seq['id']}_rep"] = {
- 'name': f"{seq['id']} (Rep)",
- 'type': 'representative_sequence',
- 'data': seq,
- 'x': x_pos,
- 'y': positions[idx],
- 'has_vertical_attachment': False,
- 'horizontal_line_length': horizontal_length
+ 'name': f"{seq['id']} (Rep)", 'type': 'representative_sequence', 'data': seq,
+ 'x': x_pos, 'y': positions[idx], 'has_vertical_attachment': False, 'horizontal_line_length': horizontal_length
}
-
except Exception as e:
print(f"Error adding representative sequences: {e}")
def _add_sequences_horizontal(self, genotype_node: Dict, sequences: List[Dict], base_y: float, parent_x: float):
- """Add sequences horizontally with similarity-based line lengths"""
+ """Adds sequences with similarity-based line lengths."""
try:
- # Define the query line length as the reference (100%)
- query_line_length = 3.0 # Base length for query sequence (100%)
-
- # Separate query and matched sequences
- query_sequences = [seq for seq in sequences if seq['is_query']]
- matched_sequences = [seq for seq in sequences if seq['is_matched'] and not seq['is_query']]
-
+ query_line_length = 3.0
+ query_sequences = [s for s in sequences if s['is_query']]
+ matched_sequences = [s for s in sequences if s['is_matched'] and not s['is_query']]
all_special_sequences = query_sequences + matched_sequences
-
if len(all_special_sequences) == 1:
- # Single sequence - direct line with similarity-based length
sequence = all_special_sequences[0]
line_length = self._calculate_similarity_based_line_length(sequence, query_line_length)
x_pos = parent_x + line_length
-
genotype_node['children'][sequence['id']] = {
- 'name': f"{sequence['id']}{' (' + str(sequence['similarity']) + '%)' if sequence['is_matched'] else ''}",
- 'type': 'sequence',
- 'data': sequence,
- 'x': x_pos,
- 'y': base_y,
- 'has_vertical_attachment': False,
- 'similarity_line_length': line_length
+ 'name': f"{sequence['id']} ({sequence['similarity']}%)" if sequence['is_matched'] else sequence['id'],
+ 'type': 'sequence', 'data': sequence, 'x': x_pos, 'y': base_y,
+ 'has_vertical_attachment': False, 'similarity_line_length': line_length
}
else:
- # Multiple sequences - equal vertical distribution with similarity-based horizontal lengths
sequence_positions = self._calculate_sequence_positions(all_special_sequences, base_y)
-
for seq_idx, sequence in enumerate(all_special_sequences):
line_length = self._calculate_similarity_based_line_length(sequence, query_line_length)
x_pos = parent_x + line_length
-
genotype_node['children'][sequence['id']] = {
- 'name': f"{sequence['id']}{' (' + str(sequence['similarity']) + '%)' if sequence['is_matched'] else ''}",
- 'type': 'sequence',
- 'data': sequence,
- 'x': x_pos,
- 'y': sequence_positions[seq_idx],
- 'has_vertical_attachment': False,
- 'similarity_line_length': line_length
+ 'name': f"{sequence['id']} ({sequence['similarity']}%)" if sequence['is_matched'] else sequence['id'],
+ 'type': 'sequence', 'data': sequence, 'x': x_pos, 'y': sequence_positions[seq_idx],
+ 'has_vertical_attachment': False, 'similarity_line_length': line_length
}
except Exception as e:
- print(f"Error adding sequences horizontally: {e}")
+ print(f"Error adding sequences: {e}")
- def _calculate_similarity_based_line_length(self, sequence: Dict, query_line_length: float) -> float:
- """Calculate line length based on similarity percentage relative to query"""
+ def _identify_query_ml_group(self, normalized_ml_groups: Dict):
+ """Identifies the ML group containing the query sequence."""
try:
- if sequence['is_query']:
- # Query sequence gets 100% length
- return query_line_length
- elif sequence['is_matched']:
- # Matched sequences get length proportional to their similarity
- similarity = sequence['similarity']
- # Convert similarity percentage to proportional length
- proportional_length = (similarity / 100.0) * query_line_length
- # Ensure minimum length for visibility
- min_length = query_line_length * 0.2 # Minimum 20% of query length
- return max(proportional_length, min_length)
- else:
- # Other sequences get a standard length (50% of query)
- return query_line_length * 0.5
+ for base_ml, ml_data in normalized_ml_groups.items():
+ if ml_data['has_special_sequences']:
+ for genotypes in ml_data['full_ml_groups'].values():
+ for sequences in genotypes.values():
+ if any(seq['is_query'] for seq in sequences):
+ self.query_ml_group = base_ml
+ return
except Exception as e:
- print(f"Error calculating similarity-based line length: {e}")
- return query_line_length * 0.5
+ print(f"Error identifying query ML group: {e}")
+
+ def _calculate_dynamic_ml_positions(self, normalized_ml_groups: Dict) -> List[float]:
+ """Calculates equal Y positions for ML groups."""
+ try:
+ ml_count = len(normalized_ml_groups)
+ if ml_count == 0:
+ return []
+ if ml_count == 1:
+ return [0.0]
+ total_spacing = (ml_count - 1) * 2.0
+ start_y = -total_spacing / 2
+ return [start_y + i * 2.0 for i in range(ml_count)]
+ except Exception as e:
+ print(f"Error calculating ML positions: {e}")
+ return list(range(len(normalized_ml_groups)))
+
+ def _calculate_full_ml_positions(self, full_ml_groups: Dict, base_y: float) -> List[float]:
+ """Calculates equal positions for full ML groups."""
+ try:
+ ml_count = len(full_ml_groups)
+ if ml_count <= 1:
+ return [base_y]
+ spacing = 1.5
+ start_y = base_y - (spacing * (ml_count - 1)) / 2
+ return [start_y + i * spacing for i in range(ml_count)]
+ except Exception as e:
+ print(f"Error calculating full ML positions: {e}")
+ return [base_y] * len(full_ml_groups)
+ def _calculate_genotype_positions(self, special_genotypes: List, base_y: float) -> List[float]:
+ """Calculates equal positions for genotypes."""
+ try:
+ genotype_count = len(special_genotypes)
+ if genotype_count <= 1:
+ return [base_y]
+ spacing = 1.0
+ start_y = base_y - (spacing * (genotype_count - 1)) / 2
+ return [start_y + i * spacing for i in range(genotype_count)]
+ except Exception as e:
+ print(f"Error calculating genotype positions: {e}")
+ return [base_y] * len(special_genotypes)
def _calculate_sequence_positions(self, sequences: List[Dict], base_y: float) -> List[float]:
- """Calculate equal positions for sequences"""
+ """Calculates equal positions for sequences."""
try:
seq_count = len(sequences)
if seq_count <= 1:
return [base_y]
-
- # Equal spacing for sequences
- spacing = 0.8 # Fixed spacing between sequences
+ spacing = 0.8
start_y = base_y - (spacing * (seq_count - 1)) / 2
-
- positions = []
- for i in range(seq_count):
- positions.append(start_y + i * spacing)
-
- return positions
-
+ return [start_y + i * spacing for i in range(seq_count)]
except Exception as e:
print(f"Error calculating sequence positions: {e}")
return [base_y] * len(sequences)
+ def _calculate_similarity_based_line_length(self, sequence: Dict, query_line_length: float) -> float:
+ """Calculates line length based on sequence similarity."""
+ try:
+ if sequence['is_query']:
+ return query_line_length
+ if sequence['is_matched']:
+ similarity = sequence['similarity']
+ proportional_length = (similarity / 100.0) * query_line_length
+ return max(proportional_length, query_line_length * 0.2)
+ return query_line_length * 0.5
+ except Exception as e:
+ print(f"Error calculating line length: {e}")
+ return query_line_length * 0.5
+
def _determine_horizontal_line_length(self, node_type: str, has_vertical: bool, contains_query: bool = False) -> float:
- """Determine horizontal line length based on node type and connections"""
+ """Determines horizontal line length based on node type."""
try:
base_length = self.base_horizontal_length
-
- # Special case: ML group containing query sequence gets much longer line
if contains_query and node_type == 'normalized_ml_group':
- return base_length * 2.5 # Much longer for query ML group
-
- # If this node has a vertical line attachment (connects to multiple children)
+ return base_length * 2.5
if has_vertical:
- # Find the current longest horizontal line with vertical
current_max = base_length
- for tracked_length in self.horizontal_line_tracker:
- if tracked_length > current_max:
- current_max = tracked_length
-
- # Make this line incrementally longer
+ for length in self.horizontal_line_tracker:
+ if length > current_max:
+ current_max = length
new_length = current_max + 0.3
self.horizontal_line_tracker.append(new_length)
return new_length
- else:
- # Direct connection (no vertical), use base length
- return base_length
-
+ return base_length
except Exception as e:
- print(f"Error determining horizontal line length: {e}")
+ print(f"Error determining line length: {e}")
return self.base_horizontal_length
- def _identify_query_ml_group(self, normalized_ml_groups: Dict):
- """Identify which ML group contains the query sequence"""
- try:
- for base_ml, ml_data in normalized_ml_groups.items():
- if ml_data['has_special_sequences']:
- for full_ml_name, genotypes in ml_data['full_ml_groups'].items():
- for genotype, sequences in genotypes.items():
- if any(seq['is_query'] for seq in sequences):
- self.query_ml_group = base_ml
- return
- except Exception as e:
- print(f"Error identifying query ML group: {e}")
-
- def _identify_query_ml_group(self, normalized_ml_groups: Dict):
- """Identify which ML group contains the query sequence"""
+ def _determine_genotype_horizontal_line_length(self, sequence_count: int, has_vertical: bool, contains_query: bool = False) -> float:
+ """Determines horizontal line length for genotype nodes."""
try:
- for base_ml, ml_data in normalized_ml_groups.items():
- if ml_data['has_special_sequences']:
- for full_ml_name, genotypes in ml_data['full_ml_groups'].items():
- for genotype, sequences in genotypes.items():
- if any(seq['is_query'] for seq in sequences):
- self.query_ml_group = base_ml
- return
+ base_length = self.base_horizontal_length
+ query_bonus = 0.5 if contains_query else 0.0
+ if sequence_count <= 1:
+ length_multiplier = 1.0
+ elif sequence_count <= 3:
+ length_multiplier = 1.6
+ elif sequence_count <= 5:
+ length_multiplier = 2.3
+ else:
+ length_multiplier = 6.0
+ return base_length * length_multiplier + query_bonus
except Exception as e:
- print(f"Error identifying query ML group: {e}")
-
- def _calculate_sequence_x_position_horizontal(self, sequence: Dict, max_similarity: float) -> float:
- """Calculate X position based on similarity percentage for horizontal layout"""
- # This function is now replaced by _calculate_similarity_based_line_length
- # Keeping for backward compatibility, but the new approach is used in _add_sequences_horizontal
-
- base_x = 0 # Relative to parent genotype node
- query_line_length = 3.0 # Reference length for query (100%)
-
- if sequence['is_query']:
- return base_x + query_line_length # 100% length for query
- elif sequence['is_matched']:
- # Line length varies based on similarity percentage
- similarity = sequence['similarity']
- proportional_length = (similarity / 100.0) * query_line_length
- min_length = query_line_length * 0.2 # Minimum 20% of query length
- return base_x + max(proportional_length, min_length)
- else:
- return base_x + (query_line_length * 0.5) # 50% length for other sequences
-
+ print(f"Error determining genotype line length: {e}")
+ return self.base_horizontal_length
- def create_interactive_tree(self, matched_ids: List[str], actual_percentage: float):
+ # --- Visualization ---
+ def create_interactive_tree(self, matched_ids: List[str], actual_percentage: float) -> Optional[go.Figure]:
+ """Creates an interactive horizontal phylogenetic tree visualization."""
try:
- print("🎨 Creating horizontal interactive tree visualization...")
-
- # Prepare data for plotting
- edge_x = []
- edge_y = []
- node_x = []
- node_y = []
- node_colors = []
- node_text = []
- node_hover = []
- node_sizes = []
-
- # Updated color scheme for new node types
+ print("🎨 Creating interactive tree visualization...")
+ edge_x, edge_y = [], []
+ node_x, node_y = [], []
+ node_colors, node_text, node_hover, node_sizes = [], [], [], []
colors = {
- 'root': '#FF0000', # Red for root
- 'normalized_ml_group': '#FFB6C1', # Light pink for normalized ML groups
- 'full_ml_group': '#FF69B4', # Hot pink for full ML groups
- 'genotype': '#FFD700', # Gold for genotypes
- 'representative_sequence': '#FFA500', # Orange for representative sequences
- 'query_sequence': '#4B0082', # Dark purple for query
- 'matched_sequence': '#6A5ACD', # Slate blue for matched
- 'other_sequence': '#87CEEB' # Sky blue for others
+ 'root': '#FF0000', 'normalized_ml_group': '#FFB6C1', 'full_ml_group': '#FF69B4',
+ 'genotype': '#FFD700', 'representative_sequence': '#FFA500', 'query_sequence': '#4B0082',
+ 'matched_sequence': '#6A5ACD', 'other_sequence': '#87CEEB'
}
def add_horizontal_edges(parent_x, parent_y, children_dict):
- """Add horizontal connecting lines with proper vertical line sizing"""
if not children_dict:
return
-
children_list = list(children_dict.values())
-
if len(children_list) == 1:
- # Single child - direct horizontal line
child = children_list[0]
edge_x.extend([parent_x, child['x'], None])
edge_y.extend([parent_y, child['y'], None])
else:
- # Multiple children - horizontal line with vertical distribution
- # Calculate the intermediate x position (where vertical line will be)
child_x_positions = [child['x'] for child in children_list]
min_child_x = min(child_x_positions)
- intermediate_x = parent_x + (min_child_x - parent_x) * 0.8 # 80% of the way to nearest child
-
- # Horizontal line to intermediate point
+ intermediate_x = parent_x + (min_child_x - parent_x) * 0.8
edge_x.extend([parent_x, intermediate_x, None])
edge_y.extend([parent_y, parent_y, None])
-
- # Calculate vertical line range to fit exactly all children
child_y_positions = [child['y'] for child in children_list]
min_y, max_y = min(child_y_positions), max(child_y_positions)
-
- # Vertical line sized exactly to fit all children
edge_x.extend([intermediate_x, intermediate_x, None])
edge_y.extend([min_y, max_y, None])
-
- # Horizontal lines from vertical line to each child
for child in children_list:
edge_x.extend([intermediate_x, child['x'], None])
edge_y.extend([child['y'], child['y'], None])
def get_node_color_and_size(node):
- """Determine node color and size based on type and content"""
if node['type'] == 'sequence':
if node['data']['is_query']:
- return colors['query_sequence'], 10 # Reduced size for compactness
- elif node['data']['is_matched']:
+ return colors['query_sequence'], 10
+ if node['data']['is_matched']:
return colors['matched_sequence'], 8
- else:
- return colors['other_sequence'], 6
- elif node['type'] == 'representative_sequence':
+ return colors['other_sequence'], 6
+ if node['type'] == 'representative_sequence':
return colors['representative_sequence'], 7
- elif node['type'] == 'normalized_ml_group':
- # Larger size if it has special sequences
- size = 9 if node.get('has_special_sequences', False) else 7
- return colors['normalized_ml_group'], size
- elif node['type'] == 'full_ml_group':
+ if node['type'] == 'normalized_ml_group':
+ return colors['normalized_ml_group'], 9 if node.get('has_special_sequences', False) else 7
+ if node['type'] == 'full_ml_group':
return colors['full_ml_group'], 8
- elif node['type'] == 'genotype':
+ if node['type'] == 'genotype':
return colors['genotype'], 7
- else:
- return colors.get(node['type'], '#000000'), 7
+ return colors.get(node['type'], '#000000'), 7
def create_node_text(node):
- """Create appropriate text label for each node type"""
if node['type'] == 'sequence':
- if node['data']['is_matched'] and not node['data']['is_query']:
- return f"{node['name']}"
- else:
- return node['name']
- elif node['type'] == 'representative_sequence':
- return node['name']
- elif node['type'] == 'normalized_ml_group':
- # Add indicator if it has special sequences
- suffix = " *" if node.get('has_special_sequences', False) else ""
- return f"{node['name']}{suffix}"
- else:
+ return f"{node['name']}" if node['data']['is_matched'] and not node['data']['is_query'] else node['name']
+ if node['type'] == 'representative_sequence':
return node['name']
+ if node['type'] == 'normalized_ml_group':
+ return f"{node['name']} *" if node.get('has_special_sequences', False) else node['name']
+ return node['name']
def create_hover_text(node):
- """Create detailed hover text for each node type"""
if node['type'] == 'sequence':
data = node['data']['data']
hover_text = (
- f"{node['name']}<br>"
- f"Type: {'Query Sequence' if node['data']['is_query'] else 'Matched Sequence' if node['data']['is_matched'] else 'Other Sequence'}<br>"
- f"ML Group: {data.get('ML', 'N/A')}<br>"
- f"Genotype: {data.get('Genotype', 'N/A')}<br>"
- f"Host: {data.get('Host', 'N/A')}<br>"
- f"Country: {data.get('Country', 'N/A')}<br>"
- f"Isolate: {data.get('Isolate', 'N/A')}<br>"
- f"Year: {data.get('Year', 'N/A')}"
+ f"{node['name']}<br>Type: {'Query' if node['data']['is_query'] else 'Matched' if node['data']['is_matched'] else 'Other'} Sequence<br>"
+ f"ML Group: {data.get('ML', 'N/A')}<br>Genotype: {data.get('Genotype', 'N/A')}<br>"
+ f"Host: {data.get('Host', 'N/A')}<br>Country: {data.get('Country', 'N/A')}<br>"
+ f"Isolate: {data.get('Isolate', 'N/A')}<br>Year: {data.get('Year', 'N/A')}"
)
if node['data']['is_matched']:
hover_text += f"<br>Similarity: {node['data']['similarity']}%"
elif node['type'] == 'representative_sequence':
data = node['data']['data']
hover_text = (
- f"{node['name']}<br>"
- f"Type: Representative Sequence<br>"
- f"ML Group: {data.get('ML', 'N/A')}<br>"
- f"Genotype: {data.get('Genotype', 'N/A')}<br>"
- f"Host: {data.get('Host', 'N/A')}<br>"
- f"Country: {data.get('Country', 'N/A')}"
+ f"{node['name']}<br>Type: Representative Sequence<br>"
+ f"ML Group: {data.get('ML', 'N/A')}<br>Genotype: {data.get('Genotype', 'N/A')}<br>"
+ f"Host: {data.get('Host', 'N/A')}<br>Country: {data.get('Country', 'N/A')}"
)
elif node['type'] == 'normalized_ml_group':
hover_text = f"{node['name']}<br>Type: Normalized ML Group"
@@ -1030,146 +695,61 @@ class PhylogeneticTreeAnalyzer:
hover_text += f"<br>Special Sequences: {special_count}/{len(node['sequences'])}"
else:
hover_text = f"{node['name']}<br>Type: {node['type'].replace('_', ' ').title()}"
-
return hover_text
def add_node_and_edges(node, parent_x=None, parent_y=None):
- """Recursively add nodes and edges to the plot with equal spacing structure."""
x, y = node['x'], node['y']
node_x.append(x)
node_y.append(y)
-
- # Get node color and size
color, size = get_node_color_and_size(node)
node_colors.append(color)
node_sizes.append(size)
-
- # Create node text and hover
node_text.append(create_node_text(node))
node_hover.append(create_hover_text(node))
-
- # Process children with equal spacing structure
if 'children' in node and node['children']:
add_horizontal_edges(x, y, node['children'])
for child in node['children'].values():
add_node_and_edges(child, x, y)
- # Build the plot data starting from root
root_node = self.tree_structure['root']
add_node_and_edges(root_node)
-
- # Add horizontal edges for root level
if root_node['children']:
add_horizontal_edges(root_node['x'], root_node['y'], root_node['children'])
- # Create the figure
fig = go.Figure()
-
- # Add edges
fig.add_trace(go.Scatter(
- x=edge_x, y=edge_y,
- mode='lines',
- line=dict(width=1, color='gray', dash='solid'), # Thinner lines for compactness
- hoverinfo='none',
- showlegend=False,
- name='Edges'
+ x=edge_x, y=edge_y, mode='lines', line=dict(width=1, color='gray'),
+ hoverinfo='none', showlegend=False
))
-
- # Add nodes
fig.add_trace(go.Scatter(
- x=node_x, y=node_y,
- mode='markers+text',
- marker=dict(
- size=node_sizes,
- color=node_colors,
- line=dict(width=1, color='black'), # Thinner borders
- opacity=0.85
- ),
- text=node_text,
- textposition="middle right",
- textfont=dict(size=9, color="black"), # Smaller font for compactness
- hoverinfo='text',
- hovertext=node_hover,
- showlegend=False,
- name='Nodes'
+ x=node_x, y=node_y, mode='markers+text',
+ marker=dict(size=node_sizes, color=node_colors, line=dict(width=1, color='black'), opacity=0.85),
+ text=node_text, textposition="middle right", textfont=dict(size=9, color="black"),
+ hoverinfo='text', hovertext=node_hover, showlegend=False
))
- # Calculate proper layout dimensions to ensure everything fits
- if node_x and node_y:
- # Get the actual data bounds
- min_x, max_x = min(node_x), max(node_x)
- min_y, max_y = min(node_y), max(node_y)
+ min_x, max_x = (min(node_x), max(node_x)) if node_x else (0, 1)  # parenthesised so the empty-list fallback guards min() as well as max()
+ min_y, max_y = (min(node_y), max(node_y)) if node_y else (0, 1)  # same guard for the y extent
+ x_range = max_x - min_x
+ y_range = max_y - min_y
+ x_padding = x_range * 0.2 if x_range > 0 else 1
+ y_padding = y_range * 0.2 if y_range > 0 else 1
+ width = min(1400, max(800, int(x_range * 80 + 400)))
+ height = min(900, max(500, int(y_range * 40 + 300)))
- # Calculate ranges
- x_range = max_x - min_x
- y_range = max_y - min_y
-
- # Add padding to ensure nothing is cut off (20% padding on each side)
- x_padding = x_range * 0.2 if x_range > 0 else 1
- y_padding = y_range * 0.2 if y_range > 0 else 1
-
- # Set axis ranges with padding
- x_axis_range = [min_x - x_padding, max_x + x_padding]
- y_axis_range = [min_y - y_padding, max_y + y_padding]
-
- # Compact but sufficient sizing
- width = min(1400, max(800, int(x_range * 80 + 400))) # Cap max width
- height = min(900, max(500, int(y_range * 40 + 300))) # Cap max height
- else:
- width, height = 800, 500
- x_axis_range = None
- y_axis_range = None
-
- # Update layout for compact horizontal tree with proper bounds
fig.update_layout(
title=dict(
- text=f"Compact Horizontal Phylogenetic Tree (ML-Based)<br>"
- f"Query: {self.query_id} | Similarity: {actual_percentage}% | "
- f"Matched: {len(matched_ids)}",
- x=0.5,
- font=dict(size=12) # Smaller title for compactness
- ),
- xaxis=dict(
- showgrid=False,
- gridcolor='lightgray',
- gridwidth=0.3, # Very thin grid lines
- zeroline=False,
- showticklabels=False,
- range=x_axis_range, # Set explicit range to prevent cutoff
- fixedrange=False, # Allow zooming if needed
- automargin=True # Automatically adjust margins
- ),
- yaxis=dict(
- showgrid=False,
- gridcolor='lightgray',
- gridwidth=0.3, # Very thin grid lines
- zeroline=False,
- showticklabels=False,
- range=y_axis_range, # Set explicit range to prevent cutoff
- fixedrange=False, # Allow zooming if needed
- automargin=True # Automatically adjust margins
+ text=f"Horizontal Phylogenetic Tree<br>Query: {self.query_id} | Similarity: {actual_percentage}% | Matched: {len(matched_ids)}",
+ x=0.5, font=dict(size=12)
),
- plot_bgcolor="white",
- paper_bgcolor="white",
- hovermode="closest",
- width=width,
- height=height,
- margin=dict(l=20, r=100, t=40, b=10), # Adequate margins, extra right margin for text
- autosize=False, # Don't auto-resize
- showlegend=True,
- legend=dict(
- x=1.02, # Position legend outside plot area
- y=1,
- xanchor='left',
- yanchor='top',
- bgcolor='rgba(255,255,255,0.8)',
- bordercolor='gray',
- borderwidth=1,
- font=dict(size=10) # Smaller legend font
- )
+ xaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[min_x - x_padding, max_x + x_padding], automargin=True),
+ yaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[min_y - y_padding, max_y + y_padding], automargin=True),
+ plot_bgcolor="white", paper_bgcolor="white", hovermode="closest",
+ width=width, height=height, margin=dict(l=20, r=100, t=40, b=10),
+ showlegend=True, legend=dict(x=1.02, y=1, xanchor='left', yanchor='top',
+ bgcolor='rgba(255,255,255,0.8)', bordercolor='gray', borderwidth=1, font=dict(size=10))
)
- # Add comprehensive legend with smaller markers
legend_elements = [
dict(name="Root", marker=dict(color=colors['root'], size=8)),
dict(name="Normalized ML Groups", marker=dict(color=colors['normalized_ml_group'], size=8)),
@@ -1178,52 +758,83 @@ class PhylogeneticTreeAnalyzer:
dict(name="Query Sequence", marker=dict(color=colors['query_sequence'], size=10)),
dict(name="Similar Sequences", marker=dict(color=colors['matched_sequence'], size=9)),
dict(name="Representative Sequences", marker=dict(color=colors['representative_sequence'], size=8)),
- dict(name="Other Sequences", marker=dict(color=colors['other_sequence'], size=7))
]
+ for element in legend_elements:
+ fig.add_trace(go.Scatter(x=[None], y=[None], mode='markers', marker=element['marker'], name=element['name'], showlegend=True))
- for i, element in enumerate(legend_elements):
- fig.add_trace(go.Scatter(
- x=[None], y=[None],
- mode='markers',
- marker=element['marker'],
- name=element['name'],
- showlegend=True
- ))
-
-
- # Configure modebar for better user experience
config = {
- 'displayModeBar': True,
- 'displaylogo': False,
- 'modeBarButtonsToRemove': ['select2d', 'lasso2d'],
- 'toImageButtonOptions': {
- 'format': 'png',
- 'filename': 'phylogenetic_tree',
- 'height': height,
- 'width': width,
- 'scale': 2
- }
+ 'displayModeBar': True, 'displaylogo': False, 'modeBarButtonsToRemove': ['select2d', 'lasso2d'],
+ 'toImageButtonOptions': {'format': 'png', 'filename': 'phylogenetic_tree', 'height': height, 'width': width, 'scale': 2}
}
-
- # Save outputs
- # try:
- # fig.write_html("phylogenetic_tree_normalized_horizontal.html", config=config)
- # print("✓ Compact horizontal interactive tree saved as 'phylogenetic_tree_normalized_horizontal.html'")
- # except Exception as e:
- # print(f"Warning: Could not save HTML file: {e}")
-
- # Display the figure with config
try:
- fig.show(config=config)
+ fig.show(config=config)  # pass config by keyword: the first positional argument of show() is the renderer
except Exception as e:
print(f"Warning: Could not display figure: {e}")
-
return fig
-
except Exception as e:
- print(f"Error creating compact horizontal interactive tree: {e}")
+ print(f"Error creating tree visualization: {e}")
return None
+ # --- ML Analysis ---
+ def perform_ml_analysis_safe(self, matched_ids: List[str]) -> Dict:
+
+ try:
+ print("\n🧬 PERFORMING MAXIMUM LIKELIHOOD ANALYSIS")
+ print("="*50)
+
+ # Include query sequence in analysis
+ all_sequences = [self.query_id] + [seq_id for seq_id in matched_ids if seq_id != self.query_id]
+
+ # Limit number of sequences to prevent memory issues
+ if len(all_sequences) > 20:
+ print(f"Warning: Limiting analysis to 20 sequences (had {len(all_sequences)})")
+ all_sequences = all_sequences[:20]
+
+ if len(all_sequences) < 3:
+ print("❌ Need at least 3 sequences for ML analysis")
+ return {}
+
+ # Step 1: Create multiple sequence alignment
+ alignment = self.create_sequence_alignment(all_sequences)
+ if not alignment:
+ return {}
+
+ # Step 2: Calculate ML distances
+ distance_matrix = self.calculate_ml_distances(alignment)
+ if distance_matrix.size == 0:
+ return {}
+
+ # Step 3: Construct ML tree
+ ml_tree = self.construct_ml_tree(alignment)
+ if not ml_tree:
+ return {}
+
+ # Step 4: Calculate tree likelihood (safely)
+ log_likelihood = self.calculate_ml_likelihood_safe(ml_tree, alignment)
+
+ # Step 5: Prepare results
+ ml_results = {
+ 'tree': ml_tree,
+ 'alignment': alignment,
+ 'distance_matrix': distance_matrix,
+ 'log_likelihood': log_likelihood,
+ 'sequence_count': len(all_sequences),
+ 'alignment_length': len(alignment[0]) if alignment else 0
+ }
+
+ print(f"✅ ML analysis completed successfully")
+ print(f" Sequences analyzed: {len(all_sequences)}")
+ print(f" Alignment length: {ml_results['alignment_length']}")
+ print(f" Log-likelihood: {log_likelihood:.2f}")
+
+ return ml_results
+
+ except Exception as e:
+ print(f"❌ ML analysis failed: {e}")
+ import traceback
+ traceback.print_exc()
+ return {}
+
def create_sequence_alignment(self, sequence_ids: List[str]) -> Optional[MultipleSeqAlignment]:
@@ -1287,7 +898,7 @@ class PhylogeneticTreeAnalyzer:
return aligned_sequences
except Exception as e:
print(f"Error in simple alignment: {e}")
- return sequences # Return original sequences as fallback
+ return sequences
def calculate_ml_distances(self, alignment: MultipleSeqAlignment) -> np.ndarray:
@@ -1327,7 +938,6 @@ class PhylogeneticTreeAnalyzer:
def _alignment_to_matrix(self, alignment: MultipleSeqAlignment) -> np.ndarray:
try:
- # Nucleotide to number mapping
nucleotide_map = {'A': 0, 'T': 1, 'G': 2, 'C': 3, 'N': 4, '-': 5}
matrix = []
@@ -1341,6 +951,7 @@ class PhylogeneticTreeAnalyzer:
print(f"Error converting alignment to matrix: {e}")
return np.array([])
+
def _calculate_ml_distance_pair(self, seq1: np.ndarray, seq2: np.ndarray) -> float:
try:
@@ -1374,639 +985,593 @@ class PhylogeneticTreeAnalyzer:
return 1.0 # Return maximum distance if log calculation fails
except Exception as e:
- return 1.0 # Return maximum distance on error
-
- def construct_ml_tree(self, alignment: MultipleSeqAlignment) -> Optional[BaseTree.Tree]:
+ return 1.0
+    def construct_ml_tree(self, alignment: MultipleSeqAlignment) -> Optional[Tree]:
+        """Build a tree for the alignment from its pairwise ML distances.
+
+        The distances from calculate_ml_distances are handed to
+        _build_nj_tree_from_distances, and the resulting tree is passed
+        through _optimize_branch_lengths_ml_safe.
+
+        Args:
+            alignment: Aligned records; each record's ``id`` is used as a
+                sequence name.
+
+        Returns:
+            The tree, or None when the distance matrix is empty, the tree
+            could not be built, or an exception was raised.
+        """
         try:
-            print("🌳 Constructing Maximum Likelihood tree...")
-
-            # Calculate ML distance matrix
+            print("🌳 Constructing ML tree...")
             distance_matrix = self.calculate_ml_distances(alignment)
-
             if distance_matrix.size == 0:
                 return None
-
-            # Create sequence names list
             sequence_names = [record.id for record in alignment]
-
-            # Build tree using neighbor-joining on ML distances
             tree = self._build_nj_tree_from_distances(distance_matrix, sequence_names)
-
-            # Optimize branch lengths using ML (with recursion protection)
             if tree:
                 tree = self._optimize_branch_lengths_ml_safe(tree, alignment)
-
-            print("✓ ML tree constructed successfully")
+                # Report success only when a tree was actually built.
+                print("✓ ML tree constructed")
             return tree
-
         except Exception as e:
             print(f"Error constructing ML tree: {e}")
             return None
-    def _build_nj_tree_from_distances(self, distance_matrix: np.ndarray, sequence_names: List[str]) -> Optional[BaseTree.Tree]:
-
+    def _build_nj_tree_from_distances(self, distance_matrix: np.ndarray, sequence_names: List[str]) -> Optional[Tree]:
+        """Build a neighbor-joining tree from a square distance matrix.
+
+        Args:
+            distance_matrix: Pairwise distances; row/column ``i`` belongs to
+                ``sequence_names[i]``. Only the part below the diagonal is read.
+            sequence_names: One name per row of ``distance_matrix``.
+
+        Returns:
+            The neighbor-joining tree, or None when the sizes disagree, the
+            tree fails _validate_tree_structure, or an exception is raised.
+        """
         try:
-            from Bio.Phylo.TreeConstruction import DistanceMatrix, DistanceTreeConstructor
-
-            # Validate inputs
             if distance_matrix.shape[0] != len(sequence_names):
-                print("Error: Distance matrix size doesn't match sequence names")
+                print("Error: Distance matrix size mismatch")
                 return None
-
-            # Convert numpy array to Bio.Phylo distance matrix format
-            matrix_data = []
-            for i in range(len(sequence_names)):
-                row = []
-                for j in range(i + 1):
-                    if i == j:
-                        row.append(0.0)
-                    else:
-                        # Ensure distance is valid
-                        dist = float(distance_matrix[i][j])
-                        if np.isnan(dist) or np.isinf(dist):
-                            dist = 1.0
-                        row.append(max(0.0, dist)) # Ensure non-negative
-                matrix_data.append(row)
-
-            # Create DistanceMatrix object
+            # Sanitise before building the lower-triangular rows: negative
+            # values are clamped to 0 and NaN/inf fall back to the maximum
+            # distance 1.0. A bare max(0.0, nan) evaluates to 0.0, which would
+            # make an undefined pair look identical, and inf would pass through.
+            distances = np.asarray(distance_matrix, dtype=float)
+            cleaned = np.where(np.isfinite(distances), np.clip(distances, 0.0, None), 1.0)
+            # Row i holds the distances to sequences 0..i-1 plus the 0.0 diagonal.
+            matrix_data = [cleaned[i, :i].tolist() + [0.0] for i in range(len(sequence_names))]
             dm = DistanceMatrix(names=sequence_names, matrix=matrix_data)
-
-            # Build tree using Neighbor-Joining
             constructor = DistanceTreeConstructor()
             tree = constructor.nj(dm)
-
-            # Validate tree structure
-            if tree and self._validate_tree_structure(tree):
-                return tree
-            else:
-                print("Warning: Tree structure validation failed")
-                return tree # Return anyway, might still be usable
-
+            if not self._validate_tree_structure(tree):
+                # Say why None is returned instead of failing silently.
+                print("Warning: Tree structure validation failed")
+                return None
+            return tree
         except Exception as e:
             print(f"Error building NJ tree: {e}")
             return None
-    def _validate_tree_structure(self, tree: BaseTree.Tree, max_depth: int = 100) -> bool:
-
+    def _validate_tree_structure(self, tree: Tree, max_depth: int = 100) -> bool:
+        """Check that the tree is shallow enough and visits no node twice.
+
+        Returns True when every node reachable through ``clades`` is reached
+        exactly once at a depth of at most ``max_depth``; False otherwise,
+        including when the walk raises.
+        """
         try:
-            visited = set()
-
-            def check_node(node, depth=0):
-                if depth > max_depth:
-                    return False
-
-                # Check for circular references
-                node_id = id(node)
-                if node_id in visited:
-                    return False
-                visited.add(node_id)
-
-                # Check children
-                for child in getattr(node, 'clades', []):
-                    if not check_node(child, depth + 1):
-                        return False
-
-                return True
-
-            return check_node(tree.root if hasattr(tree, 'root') else tree)
+            start = tree.root if hasattr(tree, 'root') else tree
+            seen_ids = set()
+            # Explicit stack of (node, depth) pairs instead of recursion.
+            pending = [(start, 0)]
+            while pending:
+                current, level = pending.pop()
+                if level > max_depth or id(current) in seen_ids:
+                    return False
+                seen_ids.add(id(current))
+                pending.extend((sub, level + 1) for sub in getattr(current, 'clades', []))
+            return True
         except Exception:
             return False
-    def _optimize_branch_lengths_ml_safe(self, tree: BaseTree.Tree, alignment: MultipleSeqAlignment) -> BaseTree.Tree:
-
+    def _optimize_branch_lengths_ml_safe(self, tree: Tree, alignment: MultipleSeqAlignment) -> Tree:
+        """Adjust every existing branch length of the tree in place.
+
+        Each clade returned by _get_clades_safe that already has a branch
+        length gets the value from _calculate_optimal_branch_length, floored
+        at 0.001. The step is best effort: on an empty sequence matrix or any
+        exception the tree is returned as it is at that point.
+
+        Args:
+            tree: Tree whose clades are modified.
+            alignment: Alignment converted with _alignment_to_matrix.
+
+        Returns:
+            The same ``tree`` object.
+        """
         try:
-            print("🔧 Optimizing branch lengths with ML...")
-
-            # Set recursion limit temporarily
+            print("🔧 Optimizing branch lengths...")
             old_limit = sys.getrecursionlimit()
-            sys.setrecursionlimit(1000)
-
+            # Only ever raise the limit: forcing 1000 would lower a limit the
+            # host application has already increased.
+            sys.setrecursionlimit(max(old_limit, 1000))
             try:
-                # Convert alignment to matrix
                 seq_matrix = self._alignment_to_matrix(alignment)
-
                 if seq_matrix.size == 0:
-                    print("Warning: Empty sequence matrix, skipping optimization")
                     return tree
-
-                # Get all internal and external nodes with depth tracking
                 all_clades = self._get_clades_safe(tree)
-
-                # Simple branch length optimization
                 for clade in all_clades:
                     if hasattr(clade, 'branch_length') and clade.branch_length is not None:
-                        try:
-                            # Calculate optimal branch length based on likelihood
-                            optimal_length = self._calculate_optimal_branch_length_safe(clade, seq_matrix)
-                            clade.branch_length = max(optimal_length, 0.001) # Minimum branch length
-                        except Exception as e:
-                            print(f"Warning: Failed to optimize branch for clade: {e}")
-                            # Keep original branch length
-                            pass
-
-                print("✓ Branch lengths optimized")
-
+                        optimal_length = self._calculate_optimal_branch_length(clade, seq_matrix)
+                        clade.branch_length = max(optimal_length, 0.001)
             finally:
-                # Restore original recursion limit
                 sys.setrecursionlimit(old_limit)
-
+            print("✓ Branch lengths optimized")
             return tree
-
         except Exception as e:
-            print(f"Warning: Branch length optimization failed: {e}")
+            print(f"Warning: Branch optimization failed: {e}")
             return tree
-    def _get_clades_safe(self, tree: BaseTree.Tree, max_depth: int = 50) -> List:
-
-        clades = []
-        visited = set()
-
-        def traverse_node(node, depth=0):
-            if depth > max_depth or id(node) in visited:
-                return
-
-            visited.add(id(node))
-            clades.append(node)
-
-            # Traverse children safely
-            try:
-                children = getattr(node, 'clades', [])
-                for child in children:
-                    traverse_node(child, depth + 1)
-            except Exception:
-                pass # Skip problematic nodes
-
+    def _get_clades_safe(self, tree: Tree, max_depth: int = 50) -> List:
+        """Collect the tree's nodes in pre-order, each at most once.
+
+        Nodes deeper than ``max_depth`` are left out. If the walk raises, a
+        warning is printed and the nodes gathered so far are returned.
+        """
+        collected = []
+        seen_ids = set()
+
+        def walk(current, level=0):
+            if level <= max_depth and id(current) not in seen_ids:
+                seen_ids.add(id(current))
+                collected.append(current)
+                for sub in getattr(current, 'clades', []):
+                    walk(sub, level + 1)
+
         try:
-            root = tree.root if hasattr(tree, 'root') else tree
-            traverse_node(root)
+            start = tree.root if hasattr(tree, 'root') else tree
+            walk(start)
         except Exception as e:
             print(f"Warning: Tree traversal error: {e}")
-
-        return clades
+        return collected
-    def _calculate_optimal_branch_length_safe(self, clade, seq_matrix: np.ndarray) -> float:
-
+    def _calculate_optimal_branch_length(self, clade, seq_matrix: np.ndarray) -> float:
+        """Return an adjusted branch length for one clade.
+
+        Terminal branches are shortened by 10 %, internal branches lengthened
+        by 10 %, and the result is clamped to [0.001, 1.0].
+
+        Args:
+            clade: Tree node (not a number) exposing ``branch_length`` and
+                optionally ``is_terminal()`` / ``name``.
+            seq_matrix: Encoded alignment; currently not used by the estimate.
+
+        Returns:
+            The adjusted length, or 0.1 when the current length is missing,
+            NaN, infinite, non-positive, or cannot be converted to float.
+        """
         try:
-            # Simplified ML branch length estimation
             if not hasattr(clade, 'branch_length') or clade.branch_length is None:
-                return 0.1 # Default branch length
-
+                return 0.1
             current_length = float(clade.branch_length)
-
-            # Validate current length
             if np.isnan(current_length) or np.isinf(current_length) or current_length <= 0:
                 return 0.1
-
-            # Simple optimization based on sequence characteristics
-            if hasattr(clade, 'name') and clade.name:
-                # For terminal nodes
-                return min(max(current_length * 0.9, 0.001), 1.0)
-            else:
-                # For internal nodes
-                return min(max(current_length * 1.1, 0.001), 1.0)
-
+            # Biopython's neighbor-joining constructor names inner clades too
+            # ("Inner1", ...), so a non-empty name does not identify a leaf.
+            # Ask the clade itself when it can tell; otherwise fall back to
+            # the name heuristic.
+            if hasattr(clade, 'is_terminal'):
+                is_tip = clade.is_terminal()
+            else:
+                is_tip = bool(getattr(clade, 'name', None))
+            scale = 0.9 if is_tip else 1.1
+            return min(max(current_length * scale, 0.001), 1.0)
         except Exception:
-            return 0.1 # Safe default
-
-    def calculate_ml_likelihood_safe(self, tree: BaseTree.Tree, alignment: MultipleSeqAlignment) -> float:
+            return 0.1
+ def calculate_ml_likelihood_safe(self, tree: Tree, alignment: MultipleSeqAlignment) -> float:
+ """Return a summed per-site log score for the alignment.
+
+ The alignment is converted with _alignment_to_matrix. Columns are sampled
+ from the first 1000 sites with a stride of max(1, n_sites // 100); a column
+ is skipped when fewer than two of its codes are below 4 (A/T/G/C in
+ _alignment_to_matrix's mapping). For the remaining columns the natural log
+ of _calculate_site_likelihood_safe is accumulated.
+
+ Returns -inf when the matrix is empty or an exception is raised.
+
+ NOTE(review): no Jukes-Cantor substitution model is evaluated in this
+ method; the per-site value comes entirely from
+ _calculate_site_likelihood_safe.
+ """
try:
- print("📈 Calculating tree likelihood...")
-
+ print("Trying to calculate tree likelihood...")
seq_matrix = self._alignment_to_matrix(alignment)
-
if seq_matrix.size == 0:
return -np.inf
-
- # Simplified likelihood calculation using Jukes-Cantor model
total_log_likelihood = 0.0
-
- # For each site in the alignment (sample subset to avoid memory issues)
- n_sites = min(seq_matrix.shape[1], 1000) # Limit sites for performance
-
- for site in range(0, n_sites, max(1, n_sites // 100)): # Sample sites
- try:
- site_pattern = seq_matrix[:, site]
-
- # Skip sites with gaps or N's
- valid_positions = site_pattern < 4
- if np.sum(valid_positions) < 2:
- continue
-
- # Calculate likelihood for this site pattern
- site_likelihood = self._calculate_site_likelihood_safe(tree, site_pattern)
-
- if site_likelihood > 0:
- total_log_likelihood += np.log(site_likelihood)
-
- except Exception as e:
- print(f"Warning: Error processing site {site}: {e}")
+ n_sites = min(seq_matrix.shape[1], 1000)
+ for site in range(0, n_sites, max(1, n_sites // 100)):
+ site_pattern = seq_matrix[:, site]
+ valid_positions = site_pattern < 4
+ if np.sum(valid_positions) < 2:
continue
-
- print(f"✓ Tree likelihood calculated: {total_log_likelihood:.2f}")
+ site_likelihood = self._calculate_site_likelihood_safe(tree, site_pattern)
+ if site_likelihood > 0:
+ total_log_likelihood += np.log(site_likelihood)
+ print(f"Likelihood: {total_log_likelihood:.2f}")
return total_log_likelihood
-
except Exception as e:
print(f"Error calculating likelihood: {e}")
return -np.inf
-    def _calculate_site_likelihood_safe(self, tree: BaseTree.Tree, site_pattern: np.ndarray) -> float:
-
+    def _calculate_site_likelihood_safe(self, tree: Tree, site_pattern: np.ndarray) -> float:
+        """Return a diversity-based likelihood score for one alignment column.
+
+        Args:
+            tree: Tree the column belongs to; accepted for interface
+                compatibility and not used by the score.
+            site_pattern: Integer codes of one column; only codes below 4
+                are counted.
+
+        Returns:
+            exp(-(unique / 4) * count * 0.1) over the counted codes, floored
+            at 1e-10; 1.0 when no code is counted; 1e-10 on any exception.
+        """
         try:
-            # Count nucleotide frequencies at this site
             valid_nucs = site_pattern[site_pattern < 4]
-
             if len(valid_nucs) == 0:
                 return 1.0
-
-            # Simple likelihood based on nucleotide diversity
             unique_nucs = len(np.unique(valid_nucs))
             total_nucs = len(valid_nucs)
-
-            # Higher diversity = lower likelihood of simple evolution
-            diversity_factor = unique_nucs / 4.0 # Normalize by 4 nucleotides
-
-            # Simple likelihood model
+            # Share of the four possible codes seen in this column.
+            diversity_factor = unique_nucs / 4.0
             likelihood = np.exp(-diversity_factor * total_nucs * 0.1)
-
-            return max(likelihood, 1e-10) # Avoid zero likelihood
-
+            # Floor keeps the caller's np.log finite.
+            return max(likelihood, 1e-10)
         except Exception:
-            return 1e-10 # Safe fallback
-
-    def perform_ml_analysis_safe(self, matched_ids: List[str]) -> Dict:
+            return 1e-10
+ # --- Reporting ---
+ def generate_detailed_report(self, matched_ids: List[str], actual_percentage: float) -> bool:
+ """
+ Generate a detailed HTML report for virologists/scientists with query details, matched sequences,
+ model performance, phylogenetic tree insights, and ML analysis results in tabular format.
+ Outputs a styled HTML file.
+ Returns True if successful, False otherwise.
+ """
try:
- print("\n🧬 PERFORMING MAXIMUM LIKELIHOOD ANALYSIS")
- print("="*50)
-
- # Include query sequence in analysis
- all_sequences = [self.query_id] + [seq_id for seq_id in matched_ids if seq_id != self.query_id]
-
- # Limit number of sequences to prevent memory issues
- if len(all_sequences) > 20:
- print(f"Warning: Limiting analysis to 20 sequences (had {len(all_sequences)})")
- all_sequences = all_sequences[:20]
-
- if len(all_sequences) < 3:
- print("❌ Need at least 3 sequences for ML analysis")
- return {}
-
- # Step 1: Create multiple sequence alignment
- alignment = self.create_sequence_alignment(all_sequences)
- if not alignment:
- return {}
-
- # Step 2: Calculate ML distances
- distance_matrix = self.calculate_ml_distances(alignment)
- if distance_matrix.size == 0:
- return {}
-
- # Step 3: Construct ML tree
- ml_tree = self.construct_ml_tree(alignment)
- if not ml_tree:
- return {}
-
- # Step 4: Calculate tree likelihood (safely)
- log_likelihood = self.calculate_ml_likelihood_safe(ml_tree, alignment)
+ print("📝 Generating detailed HTML analysis report...")
+
+ # --- HTML Template with Inline CSS ---
+ html_content = """
+
+
+
Generated on: {timestamp}
+Query ID: {query_id}
+ """ + + # Add timestamp and query ID to HTML + timestamp = time.strftime("%Y-%m-%d %H:%M:%S %Z") + html_content = html_content.format(query_id=self.query_id, timestamp=timestamp) + + # --- Query Information --- + query_type = ( + "Accession Number" if self.query_id in self.data['Accession Number'].values else + "Dataset Sequence" if self.query_sequence in self.data['F-gene'].values else + "Novel Sequence" + ) + query_ml = "Unknown" + query_genotype = "Unknown" + query_metadata = {} + + if query_type == "Novel Sequence": + query_ml = self.predict_ml_group(self.query_sequence) + query_genotype = self.predict_genotype(self.query_sequence) + query_metadata = {"F-gene": self.query_sequence[:50] + "..." if len(self.query_sequence) > 50 else self.query_sequence} + else: + query_row = self.data[ + (self.data['Accession Number'] == self.query_id) | + (self.data['F-gene'] == re.sub(r'[^ATGC]', '', self.query_sequence.upper())) + ].iloc[0] + query_ml = query_row['ML'] + query_genotype = query_row['Genotype'] + query_metadata = query_row.to_dict() + query_metadata['F-gene'] = query_metadata['F-gene'][:50] + "..." if len(query_metadata['F-gene']) > 50 else query_metadata['F-gene'] + + query_info_table = [ + ["Query ID", self.query_id], + ["Query Type", query_type], + ["Sequence Length", f"{len(self.query_sequence)} nucleotides"], + ["ML Group", query_ml], + ["Genotype", query_genotype], + ["Target Similarity", f"{self.matching_percentage}%"], + ["Actual Similarity", f"{actual_percentage:.1f}%"] + ] - # Step 5: Prepare results - ml_results = { - 'tree': ml_tree, - 'alignment': alignment, - 'distance_matrix': distance_matrix, - 'log_likelihood': log_likelihood, - 'sequence_count': len(all_sequences), - 'alignment_length': len(alignment[0]) if alignment else 0 - } + # Add Query Information section + html_content += """ +| Field | Value |
|---|---|
| {row[0]} | {row[1]} |
Total Matched Sequences: {len(matched_ids)}
+ """ + if matched_sequences_table: + html_content += """ +| {header} | " + html_content += """ +
|---|
| {cell} | " + html_content += "
No matched sequences found.
+ """ + + # --- Model Performance --- + model_performance_table = [ + ["ML Model Accuracy", f"{self.ml_model_accuracy:.2%}" if self.ml_model_accuracy else "Not trained"], + ["Genotype Model Accuracy", f"{self.genotype_model_accuracy:.2%}" if self.genotype_model_accuracy else "Not trained"] + ] - print(f"✅ ML analysis completed successfully") - print(f" Sequences analyzed: {len(all_sequences)}") - print(f" Alignment length: {ml_results['alignment_length']}") - print(f" Log-likelihood: {log_likelihood:.2f}") + # Add Model Performance section + html_content += """ +| Metric | Value |
|---|---|
| {row[0]} | {row[1]} |
| Field | Value |
|---|---|
| {row[0]} | {row[1]} |
| Field | Value |
|---|---|
| {row[0]} | {row[1]} |