# app.py
# Third-party web framework and data handling
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import pandas as pd
# Biopython: sequence objects, FASTA parsing, alignment and tree handling
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from Bio import SeqIO
from Bio.Align import MultipleSeqAlignment
from Bio import Phylo
import numpy as np
# Standard library
import copy
import os
import subprocess
import logging
from io import StringIO
# Set up logging — one module-level logger shared by all helpers below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI app
app = FastAPI(
    title="Viral Sequence Analysis API",
    description="API for analyzing RNA sequences of viruses",
    version="1.0.0"
)

# CORS: allow browser calls from localhost:3000 — presumably the frontend
# dev server; update allow_origins for production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Paths — all data files and PhyML artifacts live next to this module.
CSV_PATH = os.path.join(os.path.dirname(__file__), "f gene clean dataset.csv")
FASTA_PATH = os.path.join(os.path.dirname(__file__), "selected_sequences.fasta")
ALN_FILE = os.path.join(os.path.dirname(__file__), "alignment.phy")
# NOTE(review): PhyML appears to name its output "<alignment>_phyml_tree.txt"
# next to the input alignment — confirm against the PhyML version in use.
TREE_FILE = os.path.join(os.path.dirname(__file__), "alignment.phy_phyml_tree.txt")

# Load dataset once at import time; fail fast with a clear error if missing.
try:
    df = pd.read_csv(CSV_PATH)
except FileNotFoundError as e:
    raise RuntimeError(f"Dataset not found: {str(e)}")
# Analysis Functions
def validate_fasta(sequence: str) -> bool:
    """Return True when *sequence* parses as exactly one FASTA record.

    Any parse failure (empty input, multiple records, bad header) is
    logged and reported as False rather than raised.
    """
    handle = StringIO(sequence)
    try:
        SeqIO.read(handle, "fasta")
    except Exception as exc:
        logger.error(f"Invalid FASTA sequence: {str(exc)}")
        return False
    logger.info("FASTA sequence validated successfully")
    return True
def calculate_similarity(seq1: str, seq2: str) -> float:
    """Return the percent identity between two sequences.

    Positions are compared pairwise up to the length of the shorter
    sequence, and the match count is divided by that shorter length.
    Returns 0 when either sequence is empty.
    """
    shorter = min(len(seq1), len(seq2))
    if shorter == 0:
        return 0
    identical = 0
    for left, right in zip(seq1, seq2):
        if left == right:
            identical += 1
    return (identical / shorter) * 100
def find_similar_sequences(user_seq: str, df: pd.DataFrame, threshold: float) -> tuple:
    """Find dataset sequences whose rounded similarity is closest to *threshold*.

    Scores every non-null "F-gene" entry against *user_seq*, picks the
    rounded similarity percentage nearest the requested threshold, and
    returns (matches, chosen_percentage) where matches is a list of
    (accession, sequence) pairs at that percentage.

    Raises:
        ValueError: if the dataset contains no usable sequences.
    """
    scored = []
    for _, row in df.iterrows():
        candidate = row["F-gene"]
        if pd.notna(candidate):
            score = calculate_similarity(user_seq, candidate)
            scored.append((row["Accession Number"], candidate, score))
    if not scored:
        raise ValueError("No valid sequences found in dataset")
    # Work with whole-percent buckets and pick the bucket nearest the threshold.
    buckets = {round(entry[2]) for entry in scored}
    closest = min(buckets, key=lambda pct: abs(pct - threshold))
    logger.info(f"Using closest match threshold: {closest}%")
    hits = [(acc, seq) for acc, seq, score in scored if round(score) == closest]
    return hits, closest
def write_fasta(df: pd.DataFrame, matched: list, user_seq: str, fasta_path: str = "selected_sequences.fasta", max_other: int = 2) -> str:
    """Write matched sequences, the query, and background sequences to FASTA.

    Matched (accession, sequence) pairs are written first; if none of them
    equals the query sequence, the query is appended under the header
    "Query". Additional unmatched dataset rows are then added for tree
    context, capped at max_other * 10 records.

    Returns the FASTA id under which the query was written (its accession
    number when it was among the matches, otherwise "Query").
    """
    written_ids = set()
    query_id = None
    query_written = False
    with open(fasta_path, "w") as handle:
        # Matched sequences first; note whether one of them IS the query.
        for acc, seq in matched:
            handle.write(f">{acc}\n{seq}\n")
            written_ids.add(acc)
            if seq == user_seq:
                query_id = acc
                query_written = True
        # Guarantee the query appears under a dedicated header.
        if not query_written:
            handle.write(f">Query\n{user_seq}\n")
            query_written = True
            query_id = "Query"
        # Background sequences give the phylogeny some context.
        extras = 0
        for _, row in df.iterrows():
            acc = row["Accession Number"]
            if acc not in written_ids and pd.notna(row["F-gene"]):
                handle.write(f">{acc}\n{row['F-gene']}\n")
                written_ids.add(acc)
                extras += 1
                if extras >= max_other * 10:
                    break
    if not query_written:
        raise RuntimeError("Query sequence wasn't included in the matched set")
    logger.info(f"FASTA file created: {fasta_path}")
    return query_id
def align_sequences(fasta_path: str) -> MultipleSeqAlignment:
    """Build a naive alignment by right-padding every sequence with "-".

    All records in *fasta_path* are padded to the length of the longest
    one so they can be wrapped in a MultipleSeqAlignment.

    Raises:
        FileNotFoundError: if *fasta_path* does not exist.
    """
    if not os.path.exists(fasta_path):
        raise FileNotFoundError(f"FASTA file not found: {fasta_path}")
    records = list(SeqIO.parse(fasta_path, "fasta"))
    target_len = max(len(rec.seq) for rec in records)
    padded = [
        SeqRecord(Seq(str(rec.seq).ljust(target_len, "-")), id=rec.id, description="")
        for rec in records
    ]
    alignment = MultipleSeqAlignment(padded)
    logger.info("Sequences aligned")
    return alignment
def build_ml_tree(alignment: MultipleSeqAlignment, aln_file: str = "alignment.phy", tree_file: str = "alignment.phy_phyml_tree.txt") -> Phylo.BaseTree.Tree:
    """
    Build a maximum likelihood tree using PhyML (GTR model, no bootstrap).

    Writes *alignment* to *aln_file* in relaxed PHYLIP format, runs the
    PhyML executable on it, and reads the resulting Newick tree from
    *tree_file*.

    The PhyML executable location can be overridden with the PHYML_PATH
    environment variable; it defaults to the previous hard-coded path.

    Raises:
        RuntimeError: on any failure (missing executable, PhyML error,
            missing output tree) — the original cause is chained.
    """
    try:
        with open(aln_file, "w") as handle:
            SeqIO.write(alignment, handle, "phylip-relaxed")
        logger.info("Running PhyML for tree construction")
        # Prefer an environment override so deployments aren't tied to one machine.
        phyml_path = os.environ.get(
            "PHYML_PATH",
            r"C:\Users\Dell\Documents\LV4 Term2\tree\tree-ml\phyml.exe",
        )
        if not os.path.exists(phyml_path):
            raise FileNotFoundError(f"PhyML executable not found at: {phyml_path}")
        result = subprocess.run(
            [phyml_path, "-i", aln_file, "-d", "nt", "-m", "GTR", "-b", "0"],
            capture_output=True, text=True, check=True,
        )
        logger.info(f"PhyML output: {result.stdout}")
        # Only log stderr when PhyML actually produced any.
        if result.stderr:
            logger.error(f"PhyML errors: {result.stderr}")
        if not os.path.exists(tree_file):
            raise RuntimeError("PhyML tree file not created")
        tree = Phylo.read(tree_file, "newick")
        logger.info("ML tree constructed")
        return tree
    except Exception as e:
        logger.error(f"PhyML tree construction failed: {str(e)}")
        raise RuntimeError(f"Tree construction failed: {str(e)}") from e
def simplify_tree(tree: Phylo.BaseTree.Tree, query_id: str, matched_ids: list, max_children: int = 2) -> Phylo.BaseTree.Tree:
    """
    Simplify the tree by limiting non-query branch children.

    Works on a deep copy; the input tree is untouched. Every clade on the
    path from the query leaf up to the root is preserved; any other
    internal clade with more than max_children terminals is collapsed to
    its first max_children terminal leaves.

    Note: matched_ids is currently unused by this function.
    Raises StopIteration if query_id is not a clade name in the tree.
    """
    new_tree = copy.deepcopy(tree)
    query_node = next(new_tree.find_clades(target=query_id))
    # Bio.Phylo clades carry no parent pointers; attach them manually so
    # we can walk upward from the query node toward the root.
    for clade in new_tree.find_clades():
        for child in clade.clades:
            child.parent = clade
    preserved_branch = set()
    current = query_node
    while current != new_tree.root:
        preserved_branch.add(current)
        current = current.parent
    preserved_branch.add(new_tree.root)
    # Collapse every non-preserved internal clade that is too bushy.
    for clade in list(new_tree.find_clades()):
        if clade in preserved_branch:
            continue
        if not clade.is_terminal():
            terminals = clade.get_terminals()
            if len(terminals) > max_children:
                clade.clades = terminals[:max_children]
    logger.info("Tree simplified")
    return new_tree
def tree_to_json(tree: Phylo.BaseTree.Tree, df: pd.DataFrame, query_id: str, matched_ids: list) -> dict:
    """
    Convert the tree to a JSON format for Cytoscape.js visualization.

    Returns {"nodes": [...], "edges": [...]} where each node carries the
    dataset metadata for its accession number (empty strings when absent)
    plus flags marking the query node and matched nodes.
    """
    nodes = []
    edges = []
    seen_labels = set()  # node ids already emitted, used for de-duplication
    # Accession Number -> {Genotype, Host, Country, Isolate, Year}
    metadata_dict = df.set_index("Accession Number")[["Genotype", "Host", "Country", "Isolate", "Year"]].to_dict('index')
    def clean_label(label, fallback_id):
        # Normalize a clade name into a unique, single-line node id;
        # unnamed clades get a synthetic "Unnamed_<id>" label.
        if not label or not str(label).strip():
            label = f"Unnamed_{fallback_id}"
        else:
            label = str(label).strip().replace("\n", " ").replace("\r", "")
        if label in seen_labels:
            # NOTE(review): the fallback_id scheme below can repeat across
            # subtrees, so this suffix may still collide — confirm.
            label += f"_{fallback_id}"
        seen_labels.add(label)
        return label
    def traverse_tree(node, parent_id=None, node_id=0):
        # Depth-first walk: one node dict per clade, one edge per parent link.
        label = clean_label(node.name, node_id)
        metadata = metadata_dict.get(label, {}) if label in metadata_dict else {}
        node_data = {
            "id": label,
            "kercode": label,
            "metadata": {
                "Genotype": metadata.get("Genotype", ""),
                "Host": metadata.get("Host", ""),
                "Country": metadata.get("Country", ""),
                "Isolate": metadata.get("Isolate", ""),
                "Year": metadata.get("Year", "")
            },
            "is_query": label == query_id,
            "is_matched": label in matched_ids
        }
        nodes.append(node_data)
        if parent_id:
            edges.append({
                "source": parent_id,
                "target": label,
                # fall back to a small positive length so layouts never see 0
                "distance": node.branch_length or 0.01
            })
        for i, child in enumerate(node.clades):
            traverse_tree(child, label, node_id + i + 1)
    traverse_tree(tree.root)
    logger.info("Tree converted to JSON")
    return {"nodes": nodes, "edges": edges}
def get_virus_metadata(kercode: str, df: pd.DataFrame) -> dict:
    """Look up dataset metadata for one accession number (kercode).

    Returns a dict with Genotype, Host, Country, Isolate and Year, or
    None when the kercode is not present in the dataset.
    """
    lookup = df.set_index("Accession Number")[["Genotype", "Host", "Country", "Isolate", "Year"]].to_dict('index')
    record = lookup.get(kercode)
    if not record:
        logger.warning(f"No metadata found for {kercode}")
        return None
    logger.info(f"Retrieved metadata for {kercode}")
    return record
# FastAPI Endpoints
@app.post("/analyze")
async def analyze_sequence(file: UploadFile = File(...), threshold: float = 90.0):
    """
    Analyze an RNA sequence and return closest match and phylogenetic tree.

    Args:
        file: Uploaded FASTA file (.fasta or .fa) containing one record.
        threshold: Desired similarity percentage; the nearest available
            rounded similarity in the dataset is used.

    Returns:
        dict with the closest match (kercode, similarity, metadata) and
        the simplified tree as Cytoscape-style nodes/edges.

    Raises:
        HTTPException: 400 for invalid input, 500 for analysis failures.
    """
    # NOTE: the route decorator was missing, so this endpoint was never
    # registered with the app.
    if not file.filename.endswith((".fasta", ".fa")):
        raise HTTPException(status_code=400, detail="Invalid file format. Use FASTA.")
    try:
        sequence = (await file.read()).decode("utf-8")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to read file: {str(e)}")
    if not validate_fasta(sequence):
        raise HTTPException(status_code=400, detail="Invalid FASTA sequence")
    try:
        user_seq = str(SeqIO.read(StringIO(sequence), "fasta").seq)
        matched, closest_threshold = find_similar_sequences(user_seq, df, threshold)
        matched_ids = [acc for acc, _ in matched]
        query_id = write_fasta(df, matched, user_seq, FASTA_PATH)
        alignment = align_sequences(FASTA_PATH)
        tree = build_ml_tree(alignment, ALN_FILE, TREE_FILE)
        simplified_tree = simplify_tree(tree, query_id, matched_ids)
        tree_json = tree_to_json(simplified_tree, df, query_id, matched_ids)
        closest_kercode = matched[0][0] if matched else query_id
        closest_metadata = get_virus_metadata(closest_kercode, df)
        return {
            "closest_match": {
                "kercode": closest_kercode,
                "similarity": closest_threshold,
                "metadata": closest_metadata
            },
            "tree": tree_json
        }
    except ValueError as e:
        # Bad/unsupported input data
        raise HTTPException(status_code=400, detail=str(e))
    except RuntimeError as e:
        # Pipeline failures (e.g. PhyML)
        raise HTTPException(status_code=500, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}")
@app.get("/virus/{kercode}")
async def get_virus_metadata_endpoint(kercode: str):
    """
    Retrieve metadata for a specific virus by kercode (Accession Number).

    Raises:
        HTTPException: 404 when the kercode is not in the dataset.
    """
    # NOTE: the route decorator was missing, so this endpoint was never
    # registered with the app.
    metadata = get_virus_metadata(kercode, df)
    if not metadata:
        raise HTTPException(status_code=404, detail="Virus not found")
    return metadata