# Provenance: uploaded by hadir123, commit 99965bb ("Add AI model files and
# FastAPI backend"). Hugging Face raw-view page header (picture / raw /
# history / blame / 11.6 kB) converted to comments so the file parses.
# app.py
# Standard library
import copy
import logging
import os
import shutil
import subprocess
from io import StringIO

# Third-party
import numpy as np
import pandas as pd
from Bio import Phylo
from Bio import SeqIO
from Bio.Align import MultipleSeqAlignment
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
# Set up logging: one module-level logger used by every function below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# FastAPI app: endpoints are registered on this object further down.
app = FastAPI(
    title="Viral Sequence Analysis API",
    description="API for analyzing RNA sequences of viruses",
    version="1.0.0"
)
# CORS: allow the local front-end dev server to call this API from a browser.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Paths (all resolved relative to this file's directory).
# CSV_PATH: reference F-gene dataset loaded once below.
CSV_PATH = os.path.join(os.path.dirname(__file__), "f gene clean dataset.csv")
# FASTA_PATH: scratch FASTA rewritten on each /analyze request.
FASTA_PATH = os.path.join(os.path.dirname(__file__), "selected_sequences.fasta")
# ALN_FILE: PHYLIP alignment fed to PhyML.
ALN_FILE = os.path.join(os.path.dirname(__file__), "alignment.phy")
# TREE_FILE: name PhyML derives from ALN_FILE for its Newick output.
TREE_FILE = os.path.join(os.path.dirname(__file__), "alignment.phy_phyml_tree.txt")
# Load dataset once at startup; fail fast so the service never starts
# without its reference data.
try:
    df = pd.read_csv(CSV_PATH)
except FileNotFoundError as e:
    raise RuntimeError(f"Dataset not found: {str(e)}")
# Analysis Functions
def validate_fasta(sequence: str) -> bool:
    """Return True when *sequence* parses as exactly one FASTA record."""
    handle = StringIO(sequence)
    try:
        SeqIO.read(handle, "fasta")
    except Exception as err:
        logger.error(f"Invalid FASTA sequence: {str(err)}")
        return False
    logger.info("FASTA sequence validated successfully")
    return True
def calculate_similarity(seq1: str, seq2: str) -> float:
    """
    Percent identity between two sequences, compared position-by-position
    over their overlapping prefix (no alignment is performed).
    """
    overlap = min(len(seq1), len(seq2))
    if overlap == 0:
        # Avoid division by zero for an empty sequence.
        return 0
    matched = 0
    for base_a, base_b in zip(seq1, seq2):
        if base_a == base_b:
            matched += 1
    return matched / overlap * 100
def find_similar_sequences(user_seq: str, df: pd.DataFrame, threshold: float) -> tuple:
    """
    Select dataset sequences whose rounded similarity to the query is
    closest to the requested threshold.

    Returns:
        (matched, closest): matched is a list of (accession, sequence)
        pairs, closest is the rounded similarity percentage actually used.

    Raises:
        ValueError: when the dataset contains no non-null F-gene entries.
    """
    scored = []
    for _, row in df.iterrows():
        candidate = row["F-gene"]
        if pd.notna(candidate):
            score = calculate_similarity(user_seq, candidate)
            scored.append((row["Accession Number"], candidate, score))
    if not scored:
        raise ValueError("No valid sequences found in dataset")
    # Pick the rounded percentage nearest to the caller's threshold.
    rounded_levels = {round(entry[2]) for entry in scored}
    closest = min(rounded_levels, key=lambda pct: abs(pct - threshold))
    logger.info(f"Using closest match threshold: {closest}%")
    matched = [(acc, seq) for acc, seq, score in scored if round(score) == closest]
    return matched, closest
def write_fasta(df: pd.DataFrame, matched: list, user_seq: str, fasta_path: str = "selected_sequences.fasta", max_other: int = 2) -> str:
    """
    Write matched sequences plus the query sequence to a FASTA file, then
    pad the file with extra dataset sequences for phylogenetic context.

    Args:
        df: Reference dataset with "Accession Number" and "F-gene" columns.
        matched: (accession, sequence) pairs from the similarity search.
        user_seq: The query sequence itself.
        fasta_path: Output FASTA path (overwritten).
        max_other: Scale factor for context sequences; up to max_other * 10
            additional dataset rows are appended.

    Returns:
        The FASTA record id used for the query: a matched accession when the
        query is byte-identical to a dataset sequence, otherwise "Query".

    Note: the original implementation ended with an unreachable
    `if not query_written: raise RuntimeError(...)` guard — the query is
    always written by the time control reached it — so that dead code has
    been removed.
    """
    written_ids = set()
    query_id = None
    with open(fasta_path, "w") as f:
        # Matched sequences first; if one of them IS the query, reuse its id.
        for acc, seq in matched:
            f.write(f">{acc}\n{seq}\n")
            written_ids.add(acc)
            if seq == user_seq:
                query_id = acc
        # Otherwise the query gets its own record.
        if query_id is None:
            f.write(f">Query\n{user_seq}\n")
            query_id = "Query"
        # Pad with unseen dataset sequences so the tree has context taxa.
        reps_written = 0
        for _, row in df.iterrows():
            if row["Accession Number"] not in written_ids and pd.notna(row["F-gene"]):
                f.write(f">{row['Accession Number']}\n{row['F-gene']}\n")
                reps_written += 1
                written_ids.add(row["Accession Number"])
                if reps_written >= max_other * 10:
                    break
    logger.info(f"FASTA file created: {fasta_path}")
    return query_id
def align_sequences(fasta_path: str) -> MultipleSeqAlignment:
    """
    Produce a crude "alignment" by right-padding every record with '-' up
    to the length of the longest sequence in the file.
    """
    if not os.path.exists(fasta_path):
        raise FileNotFoundError(f"FASTA file not found: {fasta_path}")
    records = list(SeqIO.parse(fasta_path, "fasta"))
    target_len = max(len(rec.seq) for rec in records)
    padded = [
        SeqRecord(Seq(str(rec.seq).ljust(target_len, "-")), id=rec.id, description="")
        for rec in records
    ]
    alignment = MultipleSeqAlignment(padded)
    logger.info("Sequences aligned")
    return alignment
def build_ml_tree(alignment: MultipleSeqAlignment, aln_file: str = "alignment.phy", tree_file: str = "alignment.phy_phyml_tree.txt") -> Phylo.BaseTree.Tree:
    """
    Build a maximum likelihood tree from *alignment* using PhyML.

    Bug fix: the PhyML binary location was hard-coded to one developer's
    Windows machine, breaking every other deployment. It is now resolved,
    in order, from the PHYML_PATH environment variable, the system PATH
    (shutil.which), and finally the legacy hard-coded path kept for
    backward compatibility.

    Args:
        alignment: Aligned sequences written out in PHYLIP (relaxed) format.
        tree_file: Where PhyML writes its Newick tree; PhyML derives this
            name from aln_file, so keep the two consistent.

    Returns:
        The parsed Newick tree.

    Raises:
        RuntimeError: if PhyML cannot be found, exits non-zero, or produces
            no tree file.
    """
    try:
        with open(aln_file, "w") as handle:
            SeqIO.write(alignment, handle, "phylip-relaxed")
        logger.info("Running PhyML for tree construction")
        phyml_path = (
            os.environ.get("PHYML_PATH")
            or shutil.which("phyml")
            or r"C:\Users\Dell\Documents\LV4 Term2\tree\tree-ml\phyml.exe"  # legacy dev fallback
        )
        if not os.path.exists(phyml_path):
            raise FileNotFoundError(f"PhyML executable not found at: {phyml_path}")
        # GTR substitution model, nucleotide data, no bootstrap replicates.
        result = subprocess.run(
            [phyml_path, "-i", aln_file, "-d", "nt", "-m", "GTR", "-b", "0"],
            capture_output=True, text=True, check=True,
        )
        logger.info(f"PhyML output: {result.stdout}")
        if result.stderr:  # only log stderr when PhyML actually emitted some
            logger.error(f"PhyML errors: {result.stderr}")
        if not os.path.exists(tree_file):
            raise RuntimeError("PhyML tree file not created")
        tree = Phylo.read(tree_file, "newick")
        logger.info("ML tree constructed")
        return tree
    except Exception as e:
        logger.error(f"PhyML tree construction failed: {str(e)}")
        raise RuntimeError(f"Tree construction failed: {str(e)}") from e
def simplify_tree(tree: Phylo.BaseTree.Tree, query_id: str, matched_ids: list, max_children: int = 2) -> Phylo.BaseTree.Tree:
    """
    Simplify the tree by limiting non-query branch children.

    The chain of clades from the query leaf up to the root is preserved
    untouched; every other internal clade is collapsed to keep at most
    *max_children* of its terminal descendants.

    Args:
        tree: Input tree (deep-copied; the original is not mutated).
        query_id: Name of the query leaf. Must exist in the tree — next()
            raises StopIteration otherwise.
        matched_ids: Accepted but not used by this function.
        max_children: Max terminals kept under each collapsed clade.

    Returns:
        A simplified deep copy of the input tree.
    """
    new_tree = copy.deepcopy(tree)
    # find_clades(target=query_id) matches clades whose name equals query_id.
    query_node = next(new_tree.find_clades(target=query_id))
    # Biopython clades carry no parent links; attach them manually so we can
    # walk from the query leaf back up to the root.
    for clade in new_tree.find_clades():
        for child in clade.clades:
            child.parent = clade
    preserved_branch = set()
    current = query_node
    while current != new_tree.root:
        preserved_branch.add(current)
        current = current.parent
    preserved_branch.add(new_tree.root)
    for clade in list(new_tree.find_clades()):
        if clade in preserved_branch:
            continue
        # Collapse bushy non-query clades: replace the whole subtree with
        # the first max_children of its terminal leaves.
        if not clade.is_terminal():
            terminals = clade.get_terminals()
            if len(terminals) > max_children:
                clade.clades = terminals[:max_children]
    logger.info("Tree simplified")
    return new_tree
def tree_to_json(tree: Phylo.BaseTree.Tree, df: pd.DataFrame, query_id: str, matched_ids: list) -> dict:
    """
    Convert the tree to a JSON format for Cytoscape.js visualization.

    Returns:
        {"nodes": [...], "edges": [...]}: each node carries the dataset
        metadata for its accession (empty strings when unknown) plus
        is_query / is_matched flags; each edge carries the branch length
        as "distance".
    """
    nodes = []
    edges = []
    seen_labels = set()
    # Accession -> metadata row, for O(1) lookup during traversal.
    metadata_dict = df.set_index("Accession Number")[["Genotype", "Host", "Country", "Isolate", "Year"]].to_dict('index')
    def clean_label(label, fallback_id):
        # Unnamed internal clades get a synthetic label; duplicate labels
        # get a numeric suffix so node ids stay distinct in seen_labels.
        if not label or not str(label).strip():
            label = f"Unnamed_{fallback_id}"
        else:
            label = str(label).strip().replace("\n", " ").replace("\r", "")
        if label in seen_labels:
            label += f"_{fallback_id}"
        seen_labels.add(label)
        return label
    def traverse_tree(node, parent_id=None, node_id=0):
        # NOTE(review): the child id scheme (node_id + i + 1) is not globally
        # unique across subtrees, so two unnamed clades can share a fallback
        # id and clean_label's suffixing only partially disambiguates —
        # verify before relying on label uniqueness.
        label = clean_label(node.name, node_id)
        metadata = metadata_dict.get(label, {}) if label in metadata_dict else {}
        node_data = {
            "id": label,
            "kercode": label,
            "metadata": {
                "Genotype": metadata.get("Genotype", ""),
                "Host": metadata.get("Host", ""),
                "Country": metadata.get("Country", ""),
                "Isolate": metadata.get("Isolate", ""),
                "Year": metadata.get("Year", "")
            },
            "is_query": label == query_id,
            "is_matched": label in matched_ids
        }
        nodes.append(node_data)
        if parent_id:
            edges.append({
                "source": parent_id,
                "target": label,
                # Missing/zero branch lengths fall back to a small positive
                # value so the edge still renders.
                "distance": node.branch_length or 0.01
            })
        for i, child in enumerate(node.clades):
            traverse_tree(child, label, node_id + i + 1)
    traverse_tree(tree.root)
    logger.info("Tree converted to JSON")
    return {"nodes": nodes, "edges": edges}
def get_virus_metadata(kercode: str, df: pd.DataFrame) -> dict:
    """
    Look up the metadata row for *kercode* (an Accession Number).

    Returns the metadata dict, or None when the accession is unknown.
    """
    columns = ["Genotype", "Host", "Country", "Isolate", "Year"]
    lookup = df.set_index("Accession Number")[columns].to_dict('index')
    record = lookup.get(kercode)
    if record:
        logger.info(f"Retrieved metadata for {kercode}")
        return record
    logger.warning(f"No metadata found for {kercode}")
    return None
# FastAPI Endpoints
@app.post("/analyze", summary="Analyze an RNA sequence")
async def analyze_sequence(file: UploadFile = File(...), threshold: float = 90.0):
    """
    Analyze an RNA sequence and return closest match and phylogenetic tree.

    Pipeline: validate the uploaded FASTA -> find dataset sequences whose
    rounded similarity is closest to *threshold* -> write them plus the
    query to FASTA_PATH -> pad-align -> build an ML tree with PhyML ->
    simplify it and convert to Cytoscape JSON.

    Raises:
        HTTPException 400: wrong file extension, unreadable upload, invalid
            FASTA, or an empty dataset (ValueError from the pipeline).
        HTTPException 500: tree construction failure (RuntimeError) or any
            other unexpected error.
    """
    if not file.filename.endswith((".fasta", ".fa")):
        raise HTTPException(status_code=400, detail="Invalid file format. Use FASTA.")
    try:
        sequence = await file.read()
        sequence = sequence.decode("utf-8")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to read file: {str(e)}")
    if not validate_fasta(sequence):
        raise HTTPException(status_code=400, detail="Invalid FASTA sequence")
    try:
        seq_io = StringIO(sequence)
        user_seq = str(SeqIO.read(seq_io, "fasta").seq)
        matched, closest_threshold = find_similar_sequences(user_seq, df, threshold)
        matched_ids = [acc for acc, _ in matched]
        query_id = write_fasta(df, matched, user_seq, FASTA_PATH)
        alignment = align_sequences(FASTA_PATH)
        tree = build_ml_tree(alignment, ALN_FILE, TREE_FILE)
        simplified_tree = simplify_tree(tree, query_id, matched_ids)
        tree_json = tree_to_json(simplified_tree, df, query_id, matched_ids)
        # First matched accession is reported as the closest match; all
        # matched entries share the same rounded similarity percentage.
        closest_kercode = matched[0][0] if matched else query_id
        closest_metadata = get_virus_metadata(closest_kercode, df)
        return {
            "closest_match": {
                "kercode": closest_kercode,
                "similarity": closest_threshold,
                "metadata": closest_metadata
            },
            "tree": tree_json
        }
    # Map pipeline failures onto HTTP status codes: ValueError -> client
    # error, RuntimeError (PhyML) -> server error, anything else -> 500.
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except RuntimeError as e:
        raise HTTPException(status_code=500, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}")
@app.get("/virus/{kercode}", summary="Get virus metadata")
async def get_virus_metadata_endpoint(kercode: str):
    """
    Return the stored metadata for the virus identified by *kercode*,
    or a 404 when it is not in the dataset.
    """
    record = get_virus_metadata(kercode, df)
    if record:
        return record
    raise HTTPException(status_code=404, detail="Virus not found")