HantaLytics / input_sequence.py
RangGaraga's picture
Update input_sequence.py
c507c00 verified
Raw
History Blame Contribute Delete
16 kB
import streamlit as st
import numpy as np
import torch
import torch.nn as nn
import os
# Relative path of the serialized HantaSeqNet bundle; render() shows this
# path in its "model not found" error message.
MODEL_PT_PATH = "HantaSeqNet.pt"
# Value passed as max_length to the tokenizer in extract_embedding(), where
# inputs are truncated and padded to this length.
MAX_LENGTH = 512
# Example nucleotide sequences keyed by class name. render() draws one button
# per key; pressing it copies the sequence into the input text area.
SAMPLE_SEQUENCES = {
"Bayou": "TGCTGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGAGGATGATGATGATGATGATGATGATGATGTTGATGATGATGATTGCTGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGAGGATGATGATGATGATGATGATGATGATGTTGATGATGATGATTGCTGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGAGGATGATGATGATGATGATGATGATGATGTTGATGATGATGAT",
"Dobrava": "GATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATTATGTTGATAATGATGATGATGATGATGATGATCATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATTATGTTGATAATGATGATGATGATGATGATGATCATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATTATGTTGATAATGATGATGATGATGATGATGATCAT",
"Hantaan": "GCGACGACGACGACGACGACGACGACGACGATGACGACGATGACGATGACGATGATGATAATGATGATGATGATGATGATGATTATGATGAAGATGCGACGACGACGACGACGACGACGACGACGATGACGACGATGACGATGACGATGATGATAATGATGATGATGATGATGATGATTATGATGAAGATGCGACGACGACGACGACGACGACGACGACGATGACGACGATGACGATGACGATGATGATAATGATGATGATGATGATGATGATTATGATGAAGAT",
"Puumala": "GACGCTGACGATGACGATTACGATGAGGATGTTGAGGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATTATGACGCTGACGATGACGATTACGATGAGGATGTTGAGGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATTATGACGCTGACGATGACGATTACGATGAGGATGTTGAGGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATTAT",
"Seoul": "ATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGAGGATGATGATGATGATGATGATGATGATGTTGATGAAGATGATGATGAATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGAGGATGATGATGATGATGATGATGATGATGTTGATGAAGATGATGATGAATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGAGGATGATGATGATGATGATGATGATGATGTTGATGAAGATGATGATGA",
"Sin Nombre": "GCCGTTCCGATGCGGGCGACGACGACGAGGCGATGACTTCCATGACGATGACGACGATGACGACGACGACGACGACGAAGACGACGACGACGACGGCCGTTCCGATGCGGGCGACGACGACGAGGCGATGACTTCCATGACGATGACGACGATGACGACGACGACGACGACGAAGACGACGACGACGACGGCCGTTCCGATGCGGGCGACGACGACGAGGCGATGACTTCCATGACGATGACGACGATGACGACGACGACGACGACGAAGACGACGACGACGACG",
}
# Stylesheet for this page; render() injects it with
# st.markdown(PAGE_CSS, unsafe_allow_html=True).
# It loads the Space Mono and Inter fonts from Google Fonts, hides Streamlit's
# sidebar, header and footer, restyles buttons and text areas, and defines the
# classes used by the HTML snippets in render()/conf_bar(): page-header*,
# section-label, info-box, sample-hint, seq-counter, result-*, conf-*,
# top-badge and disclaimer.
PAGE_CSS = """
<style>
@import url('https://fonts.googleapis.com/css2?family=Space+Mono:wght@400;700&family=Inter:wght@300;400;500;600;700&display=swap');
html, body, [data-testid="stAppViewContainer"], [data-testid="stMain"] {
background-color: #f0f7ff !important;
font-family: 'Inter', sans-serif !important;
}
[data-testid="stSidebar"] { display: none !important; }
header[data-testid="stHeader"] { display: none !important; }
footer { display: none !important; }
.block-container {
padding: 2.5rem 3rem 4rem !important;
max-width: 1000px !important;
}
.stButton > button {
background: #2563eb !important;
color: #ffffff !important;
border: none !important;
font-family: 'Inter', sans-serif !important;
font-weight: 600 !important;
font-size: 0.9rem !important;
padding: 0.65rem 1.2rem !important;
border-radius: 8px !important;
width: 100% !important;
transition: background 0.2s !important;
}
.stButton > button:hover {
background: #1d4ed8 !important;
box-shadow: 0 4px 14px rgba(37,99,235,.25) !important;
}
.page-header {
background: linear-gradient(135deg, #eff6ff, #e0f2fe);
border: 1px solid #bae6fd;
border-radius: 14px;
padding: 1.6rem 2rem;
margin-bottom: 2rem;
display: flex;
align-items: center;
gap: 1.2rem;
}
.page-header-icon { font-size: 2.2rem; line-height: 1; }
.page-header-eyebrow {
font-family: 'Space Mono', monospace;
font-size: 0.67rem;
letter-spacing: 0.15em;
color: #0284c7;
text-transform: uppercase;
margin-bottom: 0.25rem;
}
.page-header-title {
font-family: 'Space Mono', monospace;
font-size: 1.3rem;
font-weight: 700;
color: #0f172a;
margin-bottom: 0.2rem;
}
.page-header-sub { font-size: 0.85rem; color: #64748b; }
.section-label {
font-family: 'Space Mono', monospace;
font-size: 0.67rem;
letter-spacing: 0.18em;
color: #0284c7;
text-transform: uppercase;
border-left: 3px solid #38bdf8;
padding-left: 0.7rem;
margin: 1.6rem 0 0.8rem;
display: block;
}
.info-box {
background: #f0f9ff;
border: 1px solid #bae6fd;
border-left: 4px solid #38bdf8;
border-radius: 8px;
padding: 1rem 1.2rem;
font-size: 0.87rem;
color: #0c4a6e;
line-height: 1.65;
}
.info-box strong { color: #0369a1; }
.info-box code {
background: #dbeafe;
padding: 0.1rem 0.35rem;
border-radius: 4px;
font-family: 'Space Mono', monospace;
font-size: 0.82rem;
color: #1d4ed8;
}
.sample-hint {
font-size: 0.82rem;
color: #64748b;
margin-bottom: 0.7rem;
font-weight: 500;
}
.seq-counter { font-family: 'Space Mono', monospace; font-size: 0.75rem; color: #64748b; margin-top: 0.35rem; display: inline-block; }
.seq-counter.ok { color: #16a34a; }
.seq-counter.warn { color: #d97706; }
div[data-testid="stTextArea"] label {
font-size: 0.84rem !important;
font-weight: 500 !important;
color: #374151 !important;
}
div[data-testid="stTextArea"] textarea {
background: #ffffff !important;
border: 1.5px solid #e2e8f0 !important;
border-radius: 8px !important;
color: #0f172a !important;
font-family: 'Space Mono', monospace !important;
font-size: 0.8rem !important;
line-height: 1.6 !important;
}
div[data-testid="stTextArea"] textarea:focus {
border-color: #38bdf8 !important;
box-shadow: 0 0 0 3px rgba(56,189,248,.15) !important;
}
.result-wrap {
background: linear-gradient(135deg, #eff6ff, #e0f2fe);
border: 1.5px solid #bae6fd;
border-radius: 14px;
padding: 2rem 2rem 1.6rem;
margin-top: 1.5rem;
position: relative;
overflow: hidden;
}
.result-wrap::before {
content: '';
position: absolute;
top: 0; left: 0; right: 0;
height: 4px;
background: linear-gradient(90deg, #3b82f6, #38bdf8, #06b6d4);
border-radius: 14px 14px 0 0;
}
.result-meta {
font-family: 'Space Mono', monospace;
font-size: 0.66rem;
letter-spacing: 0.13em;
text-transform: uppercase;
color: #64748b;
margin-bottom: 0.5rem;
}
.result-class {
font-family: 'Space Mono', monospace;
font-size: 1.9rem;
font-weight: 700;
color: #1d4ed8;
line-height: 1;
margin-bottom: 0.4rem;
}
.result-sub {
font-size: 0.87rem;
color: #475569;
margin-bottom: 1.6rem;
line-height: 1.55;
}
.result-sub strong { color: #1e40af; }
.conf-title {
font-family: 'Space Mono', monospace;
font-size: 0.66rem;
letter-spacing: 0.13em;
text-transform: uppercase;
color: #64748b;
margin-bottom: 0.75rem;
}
.conf-row { display: flex; align-items: center; gap: 0.9rem; margin-bottom: 0.65rem; }
.conf-lbl { font-size: 0.83rem; color: #374151; width: 90px; flex-shrink: 0; font-weight: 500; }
.conf-track { flex: 1; height: 9px; background: #dbeafe; border-radius: 100px; overflow: hidden; }
.conf-fill { height: 100%; border-radius: 100px; }
.conf-fill.top { background: linear-gradient(90deg, #2563eb, #38bdf8); }
.conf-fill.rest { background: #93c5fd; opacity: 0.45; }
.conf-pct { font-family: 'Space Mono', monospace; font-size: 0.83rem; color: #0f172a; width: 50px; text-align: right; flex-shrink: 0; font-weight: 600; }
.top-badge {
display: inline-block;
font-family: 'Space Mono', monospace;
font-size: 0.6rem;
padding: 0.1rem 0.4rem;
border-radius: 100px;
background: #dbeafe;
border: 1px solid #93c5fd;
color: #1d4ed8;
margin-left: 0.4rem;
vertical-align: middle;
}
.disclaimer {
margin-top: 1.4rem;
padding: 0.85rem 1rem;
background: rgba(251,191,36,.1);
border: 1px solid rgba(251,191,36,.35);
border-radius: 8px;
font-size: 0.79rem;
color: #92400e;
line-height: 1.55;
}
</style>
"""
class SmallANN(nn.Module):
    """Small fully connected classifier head.

    Layout: ``input_dim -> 256 -> 64 -> num_classes`` with a ReLU and a
    dropout layer (p=0.3, then p=0.2) after each of the two hidden linear
    layers. The stack is stored as ``self.network`` so parameter names take
    the form ``network.<index>.weight`` / ``network.<index>.bias``.

    Args:
        input_dim: Size of the input feature vector.
        num_classes: Number of output scores produced per sample.
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        # Layer order is significant: it fixes the indices used in the
        # state-dict keys that load_state_dict() matches against.
        layers = [
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, num_classes),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return the un-normalised class scores for ``x``."""
        scores = self.network(x)
        return scores
@st.cache_resource
def load_seq_model():
    """Load the HantaSeqNet classifier head from ``MODEL_PT_PATH``.

    The file is read with ``torch.load`` onto the CPU and is expected to be a
    mapping with the keys ``input_dim``, ``num_classes``,
    ``model_state_dict`` and ``class_names``. A ``SmallANN`` is built from the
    two sizes, the stored weights are loaded into it and it is switched to
    evaluation mode. The result is cached by ``st.cache_resource``.

    Returns:
        tuple: ``(ann, class_names)`` - the ``SmallANN`` in eval mode and the
        ``class_names`` entry of the bundle, returned as stored.

    Raises:
        FileNotFoundError: If ``MODEL_PT_PATH`` does not exist (render()
            handles this case explicitly).
        KeyError: If the bundle lacks one of the expected keys.
    """
    # The bundle ships inside the Space repo, so the relative path in
    # MODEL_PT_PATH resolves directly. Using the constant keeps this loader
    # in sync with the path render() reports in its error message.
    #
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects. That is only acceptable because the file is a trusted,
    # repo-local artifact; never point this at user-supplied files.
    bundle = torch.load(MODEL_PT_PATH, map_location="cpu", weights_only=False)
    ann = SmallANN(bundle["input_dim"], bundle["num_classes"])
    ann.load_state_dict(bundle["model_state_dict"])
    ann.eval()
    return ann, bundle["class_names"]
@st.cache_resource(show_spinner="Memuat DNABERT-2")
def load_dnabert():
    """Download and load the DNABERT-2 tokenizer and encoder.

    A snapshot of the ``RangGaraga/My_DNABert`` repository is fetched into
    ``/tmp/dnabert2`` (authenticated with the ``HF_TOKEN`` environment
    variable when it is set), then the tokenizer and model are loaded from
    that directory. The model is put in eval mode with gradients disabled on
    all of its parameters. The result is cached by ``st.cache_resource``.

    Returns:
        tuple: ``(tokenizer, dnabert)``.
    """
    # Imported lazily so the packages are only loaded when this function runs.
    from multimolecule import AutoTokenizer, DnaBert2Model
    from huggingface_hub import snapshot_download
    import os

    # Snapshot is stored under /tmp so that loading is faster.
    checkpoint_dir = snapshot_download(
        repo_id="RangGaraga/My_DNABert",
        token=os.environ.get("HF_TOKEN"),
        local_dir="/tmp/dnabert2",
    )

    load_kwargs = {"trust_remote_code": True}
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir, **load_kwargs)
    dnabert = DnaBert2Model.from_pretrained(checkpoint_dir, **load_kwargs)

    # Inference only: eval mode and no gradient tracking on any parameter.
    dnabert.eval()
    dnabert.requires_grad_(False)
    return tokenizer, dnabert
def extract_embedding(seq, tokenizer, dnabert):
    """Embed one sequence as the mean of the encoder's last hidden states.

    Args:
        seq: Nucleotide sequence string to embed.
        tokenizer: Tokenizer returned by ``load_dnabert()``.
        dnabert: Encoder returned by ``load_dnabert()``.

    Returns:
        NumPy array obtained by averaging ``last_hidden_state`` over dim 1
        and dropping the leading batch dimension.
    """
    encoded = tokenizer(
        seq,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=MAX_LENGTH,
    )
    with torch.no_grad():
        hidden = dnabert(**encoded).last_hidden_state
    # NOTE(review): the mean runs over every position along dim 1 with no
    # attention-mask weighting, so padded positions presumably contribute.
    # Keep as is unless the ANN's training embeddings were pooled differently.
    pooled = hidden.mean(dim=1)
    return pooled.squeeze(0).numpy()
def clean_sequence(raw):
    """Normalise pasted FASTA or plain-text input to one uppercase string.

    Lines whose first non-blank character is ``>`` (FASTA headers) are
    dropped. All whitespace is removed from the remaining lines - including
    spaces or tabs inside a line, as in sequences pasted in space-separated
    blocks - and the pieces are concatenated and upper-cased.

    Args:
        raw: Text as entered by the user; ``None`` or an empty string is
            accepted.

    Returns:
        The cleaned sequence, or ``""`` when there is nothing left.
    """
    if not raw:
        return ""
    pieces = []
    for line in raw.splitlines():
        if line.strip().startswith(">"):
            continue
        # split()/join removes leading, trailing and interior whitespace in
        # one pass, so "ATG CAT" is not later rejected as containing ' '.
        pieces.append("".join(line.split()))
    return "".join(pieces).upper()
def conf_bar(label, pct, is_top):
    """Build the HTML for one row of the confidence chart.

    Args:
        label: Text shown in the row's label cell.
        pct: Percentage value; formatted with one decimal and used both as
            the bar width and as the displayed number.
        is_top: When true the bar uses the ``top`` fill class and a
            "TOP" badge is appended to the label; otherwise ``rest`` is used.

    Returns:
        The HTML fragment as a string, starting with a newline.
    """
    if is_top:
        fill_kind = "top"
        marker = '<span class="top-badge">▲ TOP</span>'
    else:
        fill_kind = "rest"
        marker = ""
    shown = f"{pct:.1f}"
    return (
        "\n"
        '<div class="conf-row">\n'
        f'<div class="conf-lbl">{label}{marker}</div>\n'
        f'<div class="conf-track"><div class="conf-fill {fill_kind}" style="width:{shown}%"></div></div>\n'
        f'<div class="conf-pct">{shown}%</div>\n'
        "</div>"
    )
def render(goto):
    """Render the DNA-sequence input page and run classification on request.

    The page injects ``PAGE_CSS``, then shows a back button, a short guide,
    one button per entry of ``SAMPLE_SEQUENCES``, a text area for the
    sequence and a live length counter. When "Jalankan Klasifikasi" is
    pressed the input is cleaned and validated (minimum length, allowed
    characters), embedded via ``extract_embedding`` and classified by the ANN
    from ``load_seq_model``; the softmax probabilities are drawn as
    confidence bars. Failures while loading models or running inference are
    reported with ``st.error`` and end the run of this function.

    Args:
        goto: Callable taking a page name; called with ``"home"`` when the
            back button is pressed.

    Returns:
        None. All output goes through Streamlit calls.
    """
    # Minimum accepted sequence length in base pairs; single source for the
    # counter hint, the counter colour and the validation warning.
    min_len = 50

    st.markdown(PAGE_CSS, unsafe_allow_html=True)

    # Initialise session state for the text area and the selected sample.
    if "seq_textarea" not in st.session_state:
        st.session_state["seq_textarea"] = ""
    if "selected_sample" not in st.session_state:
        st.session_state["selected_sample"] = None

    # Drop a stale selection: if the text area no longer holds the selected
    # sample (the user edited or cleared it), the "✓" marker must not stay on.
    selected = st.session_state["selected_sample"]
    if selected is not None and (
        st.session_state["seq_textarea"] != SAMPLE_SEQUENCES.get(selected)
    ):
        st.session_state["selected_sample"] = None

    if st.button("← Kembali ke Home", key="back_seq"):
        goto("home")

    st.markdown("""
<div class="page-header">
<div class="page-header-icon">🧬</div>
<div>
<div class="page-header-eyebrow">Model 02 · HantaSeqNet</div>
<div class="page-header-title">Input Sekuens DNA</div>
<div class="page-header-sub">Klasifikasi tipe Hantavirus berdasarkan sekuens genom menggunakan DNABERT-2 + ANN.</div>
</div>
</div>
""", unsafe_allow_html=True)

    # Guide / info box.
    st.markdown('<span class="section-label">💡 Panduan</span>', unsafe_allow_html=True)
    st.markdown("""
<div class="info-box">
<strong>Apa itu sekuens DNA?</strong><br>
Sekuens DNA adalah rangkaian karakter <code>A</code>, <code>T</code>, <code>G</code>, <code>C</code>
yang merepresentasikan kode genetik virus — biasanya hasil dari sequencing laboratorium
dalam format <strong>FASTA</strong>. Jika tidak punya data sendiri,
gunakan tombol contoh di bawah untuk mencoba sistem ini secara langsung.
</div>
""", unsafe_allow_html=True)

    # Sample buttons.
    st.markdown('<span class="section-label">🧪 Contoh Sekuens</span>', unsafe_allow_html=True)
    st.markdown(
        '<p class="sample-hint">Klik salah satu kelas di bawah untuk mengisi sekuens contoh secara otomatis:</p>',
        unsafe_allow_html=True
    )
    cols = st.columns(len(SAMPLE_SEQUENCES))
    for col, (cls_name, sample_seq) in zip(cols, SAMPLE_SEQUENCES.items()):
        with col:
            is_sel = st.session_state["selected_sample"] == cls_name
            label = f"✓ {cls_name}" if is_sel else cls_name
            if st.button(label, key=f"sample_{cls_name}", use_container_width=True):
                # Written before the text area widget is created in this run,
                # then rerun so the widget picks up the new value.
                st.session_state["seq_textarea"] = sample_seq
                st.session_state["selected_sample"] = cls_name
                st.rerun()

    # Text area: key="seq_textarea" matches the session-state entry that the
    # sample buttons update.
    st.markdown('<span class="section-label">✏️ Input Sekuens DNA</span>', unsafe_allow_html=True)
    raw_input = st.text_area(
        label="Masukkan sekuens DNA (format FASTA atau plain sequence):",
        height=190,
        placeholder=(
            "Contoh plain sequence:\n"
            "ATGATGATGATGATGATGATGATGATGAT...\n\n"
            "Atau format FASTA:\n"
            ">SampleName|Seoul\n"
            "ATGATGATGATGATGATGATGATGATGAT..."
        ),
        key="seq_textarea",
    )

    # Live length counter below the text area.
    seq_len = len(clean_sequence(raw_input))
    ok = seq_len >= min_len
    counter_cls = "ok" if ok else "warn"
    counter_ico = "✅" if ok else "⚠️"
    suffix = "" if ok else f" — minimal {min_len} bp"
    st.markdown(
        f'<span class="seq-counter {counter_cls}">'
        f'{counter_ico} Panjang sekuens: <strong>{seq_len:,} bp</strong>{suffix}</span>',
        unsafe_allow_html=True
    )

    # Predict button, centred in the middle column.
    st.markdown("<div style='height:0.6rem'></div>", unsafe_allow_html=True)
    _, btn_col, _ = st.columns([1, 4, 1])
    with btn_col:
        predict_clicked = st.button("Jalankan Klasifikasi", key="predict_seq", use_container_width=True)

    if not predict_clicked:
        return

    # Validation.
    sequence = clean_sequence(raw_input)
    if len(sequence) < min_len:
        st.warning(f"Sekuens terlalu pendek. Masukkan minimal {min_len} base pair (bp).")
        return
    valid_chars = set("ATGCN")
    invalid = set(sequence) - valid_chars
    if invalid:
        st.warning(
            f"Karakter tidak valid: `{''.join(sorted(invalid))}`. "
            f"Hanya A, T, G, C, N yang diperbolehkan."
        )
        return

    # Model loading. A missing file gets its own message; any other failure
    # (e.g. a corrupt bundle or a missing key) is reported instead of raising.
    try:
        ann, class_names = load_seq_model()
    except FileNotFoundError:
        st.error(f"Model tidak ditemukan di: {MODEL_PT_PATH}")
        return
    except Exception as e:
        st.error(f"Gagal memuat model HantaSeqNet: {e}")
        return
    try:
        tokenizer, dnabert = load_dnabert()
    except Exception as e:
        st.error(f"Gagal memuat DNABERT-2: {e}")
        return

    # Inference: embedding -> ANN scores -> softmax probabilities.
    try:
        with st.spinner("Mengekstrak embedding dengan DNABERT-2"):
            emb = extract_embedding(sequence, tokenizer, dnabert)
        with torch.no_grad():
            tensor = torch.tensor(emb, dtype=torch.float32).unsqueeze(0)
            probs = torch.softmax(ann(tensor), dim=1).squeeze(0).numpy()
    except Exception as e:
        st.error(f"Gagal menjalankan klasifikasi: {e}")
        return

    # Result card: classes sorted by descending probability, top one flagged.
    sorted_idx = np.argsort(probs)[::-1]
    top_class = class_names[sorted_idx[0]]
    top_pct = probs[sorted_idx[0]] * 100
    all_bars = "".join(
        conf_bar(class_names[i], probs[i] * 100, rank == 0)
        for rank, i in enumerate(sorted_idx)
    )
    st.markdown(f"""
<div class="result-wrap">
<div class="result-meta">Hasil Klasifikasi · HantaSeqNet</div>
<div class="result-class">{top_class} Hantavirus</div>
<div class="result-sub">
Sekuens yang diinput paling mirip dengan tipe
<strong>{top_class}</strong> dengan confidence
<strong>{top_pct:.1f}%</strong>.
</div>
<div class="conf-title">Distribusi Confidence per Kelas</div>
{all_bars}
<div class="disclaimer">
⚠️ <strong>Perhatian:</strong> Hasil klasifikasi ini didasarkan pada similarity embedding
sekuens terhadap data training. Konfirmasi dengan analisis filogenetik atau uji PCR spesifik
tetap diperlukan. Untuk keperluan riset dan edukasi saja.
</div>
</div>
""", unsafe_allow_html=True)