File size: 3,740 Bytes
b52d4b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import gradio as gr
import torch
import joblib
import numpy as np
from transformers import AutoTokenizer, AutoModelForMaskedLM
from Bio import SeqIO
import io
from sklearn.metrics import silhouette_score, silhouette_samples
import matplotlib.pyplot as plt
import seaborn as sns
import os

# HuggingFace model id for the DNA embedding backbone.
MODEL_NAME = "InstaDeepAI/nucleotide-transformer-v2-50m-multi-species"
# Pre-trained HDBSCAN clusterer serialized with joblib.
HDBSCAN_MODEL_PATH = "hdbscan_model.pkl"
# Tokenizer truncation length (in tokens) for each k-mer string.
MAX_LENGTH = 20

# Directory holding the pre-rendered scatter/heatmap images returned by the UI.
PLOTS_DIR = "plots"
os.makedirs(PLOTS_DIR, exist_ok=True)

# Load the transformer once at import time and pin it to GPU when available;
# eval() disables dropout so embeddings are deterministic.
print("Loading Transformer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME, trust_remote_code=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device).eval()
print("Transformer loaded.")

# NOTE(review): joblib.load unpickles arbitrary code — only load trusted files.
print("Loading HDBSCAN...")
clusterer = joblib.load(HDBSCAN_MODEL_PATH)
print("HDBSCAN loaded.")

def seq_to_kmers(seq, k=6):
    """Return *seq* as space-separated overlapping k-mers (uppercased).

    A sequence shorter than *k* yields an empty string.
    """
    upper = seq.upper()
    windows = (upper[start:start + k] for start in range(len(upper) - k + 1))
    return " ".join(windows)

def analyze_fasta(fasta_bytes):
    """Cluster the DNA sequences in an uploaded FASTA file.

    Parameters
    ----------
    fasta_bytes : bytes
        Raw file contents as delivered by Gradio's ``gr.File(type="binary")``.

    Returns
    -------
    tuple
        ``(summary, scatter_path, heatmap_path)`` — ``summary`` is a dict with
        the overall silhouette score and one result entry per sequence; the two
        paths point at pre-rendered plot images (not regenerated here).
        On any failure a well-formed fallback payload is returned instead of
        raising, so the UI never sees a stack trace.
    """
    # Derive the plot paths from PLOTS_DIR once, instead of repeating literals.
    scatter_path = os.path.join(PLOTS_DIR, "scatter.png")
    heatmap_path = os.path.join(PLOTS_DIR, "heatmap.png")
    try:
        # Decode bytes -> str -> StringIO so SeqIO can parse in text mode.
        fasta_io = io.StringIO(fasta_bytes.decode("utf-8", errors="ignore"))

        ids, sequences = [], []
        for record in SeqIO.parse(fasta_io, "fasta"):
            ids.append(record.id)
            sequences.append(str(record.seq))

        if not sequences:
            return {
                "overall_silhouette": 0,
                "results": [{"id": "N/A", "cluster": -1, "confidence": 0, "note": "No sequences found"}]
            }, scatter_path, heatmap_path

        # Embed each sequence: k-mer tokenization -> transformer ->
        # mean-pool of the last hidden layer.
        batch_kmers = [seq_to_kmers(s) for s in sequences]
        inputs = tokenizer(
            batch_kmers, return_tensors="pt", padding=True, truncation=True, max_length=MAX_LENGTH
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
            last_hidden = outputs.hidden_states[-1]
            mean_embeddings = last_hidden.mean(dim=1).cpu().numpy()

        # NOTE(review): fit_predict RE-FITS the loaded HDBSCAN on the new
        # embeddings rather than assigning points to the pre-trained
        # clustering; hdbscan.approximate_predict may be the intended call —
        # confirm before changing, as it alters results.
        labels = np.asarray(clusterer.fit_predict(mean_embeddings))
        # Binary confidence: 1.0 for clustered points, 0.0 for noise (-1).
        strengths = (labels != -1).astype(float)

        # Silhouette is only defined for >= 2 clusters among non-noise points.
        valid_mask = labels != -1
        silhouette_avg = 0
        if np.unique(labels[valid_mask]).shape[0] > 1:
            silhouette_avg = silhouette_score(mean_embeddings[valid_mask], labels[valid_mask])

        results = []
        for seq_id, label, strength in zip(ids, labels, strengths):
            result = {
                "id": seq_id,
                "cluster": int(label),
                "confidence": round(float(strength), 3),
            }
            if label == -1:
                result["note"] = "Potential novel/unknown sequence"
            results.append(result)

        return (
            {"overall_silhouette": round(float(silhouette_avg), 3), "results": results},
            scatter_path,   # pre-rendered plots saved elsewhere
            heatmap_path,
        )

    except Exception as e:
        # Best-effort fallback: the Gradio outputs must always be populated.
        return {
            "overall_silhouette": 0,
            "results": [{"id": "N/A", "cluster": -1, "confidence": 0, "note": f"Fallback: {str(e)}"}],
        }, scatter_path, heatmap_path


# Gradio UI: FASTA upload -> (JSON results, scatter image, heatmap image).
demo = gr.Interface(
    fn=analyze_fasta,
    # type="binary" hands analyze_fasta the raw file bytes (not a path).
    inputs=gr.File(file_types=[".fasta"], type="binary"),
    outputs=[gr.JSON(), gr.Image(), gr.Image()],
    title="DNA Clustering Analyzer",
    # Fixed mojibake: the arrow was a mis-decoded UTF-8 sequence.
    description="Upload a FASTA file → Get clustering results + scatter plot + heatmap."
)

if __name__ == "__main__":
    demo.launch()