import streamlit as st import numpy as np import torch import torch.nn as nn import os MODEL_PT_PATH = "HantaSeqNet.pt" MAX_LENGTH = 512 SAMPLE_SEQUENCES = { "Bayou": "TGCTGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGAGGATGATGATGATGATGATGATGATGATGTTGATGATGATGATTGCTGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGAGGATGATGATGATGATGATGATGATGATGTTGATGATGATGATTGCTGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGAGGATGATGATGATGATGATGATGATGATGTTGATGATGATGAT", "Dobrava": "GATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATTATGTTGATAATGATGATGATGATGATGATGATCATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATTATGTTGATAATGATGATGATGATGATGATGATCATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATTATGTTGATAATGATGATGATGATGATGATGATCAT", "Hantaan": "GCGACGACGACGACGACGACGACGACGACGATGACGACGATGACGATGACGATGATGATAATGATGATGATGATGATGATGATTATGATGAAGATGCGACGACGACGACGACGACGACGACGACGATGACGACGATGACGATGACGATGATGATAATGATGATGATGATGATGATGATTATGATGAAGATGCGACGACGACGACGACGACGACGACGACGATGACGACGATGACGATGACGATGATGATAATGATGATGATGATGATGATGATTATGATGAAGAT", "Puumala": "GACGCTGACGATGACGATTACGATGAGGATGTTGAGGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATTATGACGCTGACGATGACGATTACGATGAGGATGTTGAGGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATTATGACGCTGACGATGACGATTACGATGAGGATGTTGAGGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATTAT", "Seoul": "ATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGAGGATGATGATGATGATGATGATGATGATGTTGATGAAGATGATGATGAATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGAGGATGATGATGATGATGATGATGATGATGTTGATGAAGATGATGATGAATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGAGGATGATGATGATGATGATGATGATGATGTTGATGAAGATGATGATGA", "Sin Nombre": "GCCGTTCCGATGCGGGCGACGACGACGAGGCGATGACTTCCATGACGATGACGACGATGACGACGACGACGACGACGAAGACGACGACGACGACGGCCGTTCCGATGCGGGCGACGACGACGAGGCGATGACTTCCATGACGATGACGACGATGACGACGACGACGACGACGAAGACGACGACGACGACGGCCGTTCCGATGCGGGCGACGACGACGAGGCGATGACTTCCATGACGATGACGACGATGACGACGACGACGACGACGAAGACGACGACGACGACG", } PAGE_CSS = """ """ class SmallANN(nn.Module): def __init__(self, input_dim, num_classes): super().__init__() self.network = nn.Sequential( nn.Linear(input_dim, 256), nn.ReLU(), nn.Dropout(0.3), nn.Linear(256, 64), nn.ReLU(), nn.Dropout(0.2), nn.Linear(64, num_classes) ) def forward(self, x): return self.network(x) @st.cache_resource def load_seq_model(): import os # HantaSeqNet.pt ada di Space repo, jadi path relatif langsung path = "HantaSeqNet.pt" bundle = torch.load(path, map_location="cpu", weights_only=False) ann = SmallANN(bundle["input_dim"], bundle["num_classes"]) ann.load_state_dict(bundle["model_state_dict"]) ann.eval() return ann, bundle["class_names"] @st.cache_resource(show_spinner="Memuat DNABERT-2") def load_dnabert(): from multimolecule import AutoTokenizer, DnaBert2Model from huggingface_hub import snapshot_download import os token = os.environ.get("HF_TOKEN") local_dir = snapshot_download( repo_id="RangGaraga/My_DNABert", token=token, local_dir="/tmp/dnabert2" # ← simpan ke /tmp agar lebih cepat ) tokenizer = AutoTokenizer.from_pretrained(local_dir, trust_remote_code=True) dnabert = DnaBert2Model.from_pretrained(local_dir, trust_remote_code=True) dnabert.eval() for p in dnabert.parameters(): p.requires_grad = False return tokenizer, dnabert def extract_embedding(seq, tokenizer, dnabert): inputs = tokenizer(seq, return_tensors="pt", truncation=True, padding="max_length", max_length=MAX_LENGTH) with torch.no_grad(): out = dnabert(**inputs) return out.last_hidden_state.mean(dim=1).squeeze(0).numpy() def clean_sequence(raw): lines = [] for line in raw.strip().splitlines(): line = line.strip() if not line.startswith(">"): lines.append(line.upper()) return "".join(lines) def conf_bar(label, pct, is_top): cls = "top" if is_top else "rest" badge = '▲ TOP' if is_top else "" return f"""
{label}{badge}
{pct:.1f}%
""" def render(goto): st.markdown(PAGE_CSS, unsafe_allow_html=True) # Inisialisasi session state untuk textarea dan selected sample if "seq_textarea" not in st.session_state: st.session_state["seq_textarea"] = "" if "selected_sample" not in st.session_state: st.session_state["selected_sample"] = None if st.button("← Kembali ke Home", key="back_seq"): goto("home") st.markdown(""" """, unsafe_allow_html=True) # INFO st.markdown('💡 Panduan', unsafe_allow_html=True) st.markdown("""
Apa itu sekuens DNA?
Sekuens DNA adalah rangkaian karakter A, T, G, C yang merepresentasikan kode genetik virus — biasanya hasil dari sequencing laboratorium dalam format FASTA. Jika tidak punya data sendiri, gunakan tombol contoh di bawah untuk mencoba sistem ini secara langsung.
""", unsafe_allow_html=True) # SAMPLE BUTTONS st.markdown('🧪 Contoh Sekuens', unsafe_allow_html=True) st.markdown( '

Klik salah satu kelas di bawah untuk mengisi sekuens contoh secara otomatis:

', unsafe_allow_html=True ) cols = st.columns(len(SAMPLE_SEQUENCES)) for i, cls_name in enumerate(SAMPLE_SEQUENCES): with cols[i]: is_sel = st.session_state["selected_sample"] == cls_name label = f"✓ {cls_name}" if is_sel else cls_name if st.button(label, key=f"sample_{cls_name}", use_container_width=True): st.session_state["seq_textarea"] = SAMPLE_SEQUENCES[cls_name] st.session_state["selected_sample"] = cls_name st.rerun() # TEXTAREA — key="seq_textarea" sesuai session state yang diupdate tombol st.markdown('✏️ Input Sekuens DNA', unsafe_allow_html=True) raw_input = st.text_area( label="Masukkan sekuens DNA (format FASTA atau plain sequence):", height=190, placeholder=( "Contoh plain sequence:\n" "ATGATGATGATGATGATGATGATGATGAT...\n\n" "Atau format FASTA:\n" ">SampleName|Seoul\n" "ATGATGATGATGATGATGATGATGATGAT..." ), key="seq_textarea", ) seq_len = len(clean_sequence(raw_input)) ok = seq_len >= 50 counter_cls = "ok" if ok else "warn" counter_ico = "✅" if ok else "⚠️" suffix = "" if ok else " — minimal 50 bp" st.markdown( f'' f'{counter_ico} Panjang sekuens: {seq_len:,} bp{suffix}', unsafe_allow_html=True ) # PREDICT BUTTON st.markdown("
", unsafe_allow_html=True) _, btn_col, _ = st.columns([1, 4, 1]) with btn_col: predict_clicked = st.button("Jalankan Klasifikasi", key="predict_seq", use_container_width=True) # HASIL if predict_clicked: sequence = clean_sequence(raw_input) if len(sequence) < 50: st.warning("Sekuens terlalu pendek. Masukkan minimal 50 base pair (bp).") return valid_chars = set("ATGCN") invalid = set(sequence) - valid_chars if invalid: st.warning( f"Karakter tidak valid: `{''.join(sorted(invalid))}`. " f"Hanya A, T, G, C, N yang diperbolehkan." ) return try: ann, class_names = load_seq_model() except FileNotFoundError: st.error(f"Model tidak ditemukan di: {MODEL_PT_PATH}") return try: tokenizer, dnabert = load_dnabert() except Exception as e: st.error(f"Gagal memuat DNABERT-2: {e}") return with st.spinner("Mengekstrak embedding dengan DNABERT-2"): emb = extract_embedding(sequence, tokenizer, dnabert) with torch.no_grad(): tensor = torch.tensor(emb, dtype=torch.float32).unsqueeze(0) probs = torch.softmax(ann(tensor), dim=1).squeeze(0).numpy() sorted_idx = np.argsort(probs)[::-1] top_class = class_names[sorted_idx[0]] top_pct = probs[sorted_idx[0]] * 100 all_bars = "".join( conf_bar(class_names[i], probs[i] * 100, rank == 0) for rank, i in enumerate(sorted_idx) ) st.markdown(f"""
Hasil Klasifikasi · HantaSeqNet
{top_class} Hantavirus
Sekuens yang diinput paling mirip dengan tipe {top_class} dengan confidence {top_pct:.1f}%.
Distribusi Confidence per Kelas
{all_bars}
⚠️ Perhatian: Hasil klasifikasi ini didasarkan pada similarity embedding sekuens terhadap data training. Konfirmasi dengan analisis filogenetik atau uji PCR spesifik tetap diperlukan. Untuk keperluan riset dan edukasi saja.
""", unsafe_allow_html=True)