Spaces:
Sleeping
Sleeping
File size: 3,740 Bytes
b52d4b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
import gradio as gr
import torch
import joblib
import numpy as np
from transformers import AutoTokenizer, AutoModelForMaskedLM
from Bio import SeqIO
import io
from sklearn.metrics import silhouette_score, silhouette_samples
import matplotlib.pyplot as plt
import seaborn as sns
import os
# Pretrained DNA language model used to embed the uploaded sequences.
MODEL_NAME = "InstaDeepAI/nucleotide-transformer-v2-50m-multi-species"
# Clustering model restored with joblib at startup.
HDBSCAN_MODEL_PATH = "hdbscan_model.pkl"
# Token budget per sequence handed to the tokenizer; longer inputs are truncated.
MAX_LENGTH = 20
# Folder holding the plot images the UI returns; created on startup if absent.
PLOTS_DIR = "plots"

os.makedirs(PLOTS_DIR, exist_ok=True)

# Choose the compute device once; every model input is moved here.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

print("Loading Transformer...")
# NOTE(review): trust_remote_code=True executes code from the model repo;
# consider pinning `revision=` to a known-good commit.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME, trust_remote_code=True)
model.to(device)
model.eval()  # inference only: disable dropout
print("Transformer loaded.")

print("Loading HDBSCAN...")
# joblib.load unpickles the file, which can run arbitrary code: only ever
# point HDBSCAN_MODEL_PATH at a trusted, locally produced artifact.
clusterer = joblib.load(HDBSCAN_MODEL_PATH)
print("HDBSCAN loaded.")
def seq_to_kmers(seq, k=6):
    """Return *seq* as overlapping length-*k* windows separated by single spaces.

    The sequence is upper-cased first. A sequence shorter than *k* yields an
    empty string.
    """
    upper = seq.upper()
    window_count = len(upper) - k + 1
    windows = (upper[start:start + k] for start in range(window_count))
    return " ".join(windows)
# Pre-rendered plot images (under PLOTS_DIR) returned with every result; they
# are NOT regenerated from the uploaded data.
_SCATTER_PLOT = "plots/scatter.png"
_HEATMAP_PLOT = "plots/heatmap.png"
# Sequences per transformer forward pass; bounds peak memory on large uploads.
_EMBED_BATCH_SIZE = 64


def _no_result(note):
    """Build the placeholder response used when no clustering can be reported."""
    return (
        {
            "overall_silhouette": 0,
            "results": [{"id": "N/A", "cluster": -1, "confidence": 0, "note": note}],
        },
        _SCATTER_PLOT,
        _HEATMAP_PLOT,
    )


def _parse_fasta(fasta_str):
    """Parse FASTA text into parallel lists of record ids and sequence strings."""
    ids, sequences = [], []
    for record in SeqIO.parse(io.StringIO(fasta_str), "fasta"):
        ids.append(record.id)
        sequences.append(str(record.seq))
    return ids, sequences


def _embed_sequences(sequences, batch_size=_EMBED_BATCH_SIZE):
    """Embed each sequence as the mean of its last-layer hidden states.

    Returns a NumPy array with one row per input sequence, in input order.
    Pooling is restricted to real tokens via the tokenizer's attention mask,
    so padding added to shorter sequences in a batch does not dilute their
    embeddings.
    """
    chunks = []
    for start in range(0, len(sequences), batch_size):
        batch_kmers = [seq_to_kmers(s) for s in sequences[start:start + batch_size]]
        inputs = tokenizer(
            batch_kmers,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=MAX_LENGTH,
        )
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
        last_hidden = outputs.hidden_states[-1]

        attention_mask = inputs.get("attention_mask")
        if attention_mask is None:
            # No mask available: every position is a real token.
            pooled = last_hidden.mean(dim=1)
        else:
            mask = attention_mask.unsqueeze(-1).to(last_hidden.dtype)
            token_counts = mask.sum(dim=1).clamp(min=1)  # avoid divide-by-zero
            pooled = (last_hidden * mask).sum(dim=1) / token_counts
        chunks.append(pooled.cpu().numpy())
    return np.concatenate(chunks, axis=0)


def _silhouette(embeddings, labels):
    """Mean silhouette over points with label != -1, or 0.0 when undefined."""
    valid = labels != -1
    n_valid = int(valid.sum())
    n_clusters = np.unique(labels[valid]).shape[0]
    # silhouette_score only accepts 2 <= n_clusters <= n_samples - 1.
    if 2 <= n_clusters <= n_valid - 1:
        return float(silhouette_score(embeddings[valid], labels[valid]))
    return 0.0


def analyze_fasta(fasta_bytes):
    """Cluster the sequences of an uploaded FASTA file.

    Args:
        fasta_bytes: Raw bytes of the uploaded file (the Gradio ``gr.File``
            input uses ``type="binary"``), or ``None`` if nothing was uploaded.

    Returns:
        A 3-tuple ``(payload, scatter_path, heatmap_path)``. ``payload`` holds
        ``overall_silhouette`` (mean silhouette over points not labelled -1,
        0 when undefined) and ``results``: one dict per record with ``id``,
        ``cluster`` and ``confidence`` (1.0 if clustered, 0.0 if labelled -1;
        those records also get a ``note``). The image paths are fixed files
        and do not depend on the upload.

    Never raises: any exception is printed to stderr and reported as a single
    placeholder result whose note starts with "Fallback:".
    """
    try:
        if fasta_bytes is None:
            return _no_result("No file uploaded")

        # Decode bytes -> str; undecodable bytes are dropped rather than fatal.
        fasta_str = fasta_bytes.decode("utf-8", errors="ignore")
        if not fasta_str.strip():
            return _no_result("No sequences found")

        ids, sequences = _parse_fasta(fasta_str)
        if not sequences:
            return _no_result("No sequences found")

        embeddings = _embed_sequences(sequences)

        # NOTE(review): fit_predict re-fits the module-level clusterer on every
        # upload, so the clustering saved at HDBSCAN_MODEL_PATH is overwritten
        # rather than used for prediction, and very small uploads may fail to
        # cluster. If the model was trained with prediction_data=True,
        # hdbscan.approximate_predict would assign points to the saved
        # clusters instead — confirm before switching.
        labels = np.asarray(clusterer.fit_predict(embeddings))
        silhouette_avg = _silhouette(embeddings, labels)

        results = []
        for seq_id, label in zip(ids, labels):
            unassigned = label == -1
            result = {
                "id": seq_id,
                "cluster": int(label),
                "confidence": 0.0 if unassigned else 1.0,
            }
            if unassigned:
                result["note"] = "Potential novel/unknown sequence"
            results.append(result)

        return (
            {"overall_silhouette": round(silhouette_avg, 3), "results": results},
            _SCATTER_PLOT,  # use existing saved scatter
            _HEATMAP_PLOT,  # use existing saved heatmap
        )
    except Exception as e:
        # UI boundary: keep the app responsive, but don't lose the traceback.
        import traceback

        traceback.print_exc()
        return _no_result(f"Fallback: {str(e)}")
# Gradio UI: one FASTA upload in; clustering JSON plus two plot images out.
demo = gr.Interface(
    fn=analyze_fasta,
    # Accept the common FASTA extensions, not only ".fasta".
    inputs=gr.File(file_types=[".fasta", ".fa", ".fna", ".fas"], type="binary"),
    outputs=[
        gr.JSON(label="Clustering results"),
        gr.Image(label="Scatter plot"),
        gr.Image(label="Heatmap"),
    ],
    title="DNA Clustering Analyzer",
    description="Upload a FASTA file → Get clustering results + scatter plot + heatmap.",
)

if __name__ == "__main__":
    demo.launch()
|