import streamlit as st import numpy as np import torch import torch.nn as nn import os MODEL_PT_PATH = "HantaSeqNet.pt" MAX_LENGTH = 512 SAMPLE_SEQUENCES = { "Bayou": "TGCTGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGAGGATGATGATGATGATGATGATGATGATGTTGATGATGATGATTGCTGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGAGGATGATGATGATGATGATGATGATGATGTTGATGATGATGATTGCTGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGAGGATGATGATGATGATGATGATGATGATGTTGATGATGATGAT", "Dobrava": "GATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATTATGTTGATAATGATGATGATGATGATGATGATCATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATTATGTTGATAATGATGATGATGATGATGATGATCATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATTATGTTGATAATGATGATGATGATGATGATGATCAT", "Hantaan": "GCGACGACGACGACGACGACGACGACGACGATGACGACGATGACGATGACGATGATGATAATGATGATGATGATGATGATGATTATGATGAAGATGCGACGACGACGACGACGACGACGACGACGATGACGACGATGACGATGACGATGATGATAATGATGATGATGATGATGATGATTATGATGAAGATGCGACGACGACGACGACGACGACGACGACGATGACGACGATGACGATGACGATGATGATAATGATGATGATGATGATGATGATTATGATGAAGAT", "Puumala": "GACGCTGACGATGACGATTACGATGAGGATGTTGAGGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATTATGACGCTGACGATGACGATTACGATGAGGATGTTGAGGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATTATGACGCTGACGATGACGATTACGATGAGGATGTTGAGGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATTAT", "Seoul": "ATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGAGGATGATGATGATGATGATGATGATGATGTTGATGAAGATGATGATGAATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGAGGATGATGATGATGATGATGATGATGATGTTGATGAAGATGATGATGAATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGATGAGGATGATGATGATGATGATGATGATGATGTTGATGAAGATGATGATGA", "Sin Nombre": "GCCGTTCCGATGCGGGCGACGACGACGAGGCGATGACTTCCATGACGATGACGACGATGACGACGACGACGACGACGAAGACGACGACGACGACGGCCGTTCCGATGCGGGCGACGACGACGAGGCGATGACTTCCATGACGATGACGACGATGACGACGACGACGACGACGAAGACGACGACGACGACGGCCGTTCCGATGCGGGCGACGACGACGAGGCGATGACTTCCATGACGATGACGACGATGACGACGACGACGACGACGAAGACGACGACGACGACG", } PAGE_CSS = """ """ class SmallANN(nn.Module): def __init__(self, input_dim, num_classes): super().__init__() self.network = nn.Sequential( nn.Linear(input_dim, 256), nn.ReLU(), nn.Dropout(0.3), nn.Linear(256, 64), nn.ReLU(), nn.Dropout(0.2), nn.Linear(64, num_classes) ) def forward(self, x): return self.network(x) @st.cache_resource def load_seq_model(): import os # HantaSeqNet.pt ada di Space repo, jadi path relatif langsung path = "HantaSeqNet.pt" bundle = torch.load(path, map_location="cpu", weights_only=False) ann = SmallANN(bundle["input_dim"], bundle["num_classes"]) ann.load_state_dict(bundle["model_state_dict"]) ann.eval() return ann, bundle["class_names"] @st.cache_resource(show_spinner="Memuat DNABERT-2") def load_dnabert(): from multimolecule import AutoTokenizer, DnaBert2Model from huggingface_hub import snapshot_download import os token = os.environ.get("HF_TOKEN") local_dir = snapshot_download( repo_id="RangGaraga/My_DNABert", token=token, local_dir="/tmp/dnabert2" # ← simpan ke /tmp agar lebih cepat ) tokenizer = AutoTokenizer.from_pretrained(local_dir, trust_remote_code=True) dnabert = DnaBert2Model.from_pretrained(local_dir, trust_remote_code=True) dnabert.eval() for p in dnabert.parameters(): p.requires_grad = False return tokenizer, dnabert def extract_embedding(seq, tokenizer, dnabert): inputs = tokenizer(seq, return_tensors="pt", truncation=True, padding="max_length", max_length=MAX_LENGTH) with torch.no_grad(): out = dnabert(**inputs) return out.last_hidden_state.mean(dim=1).squeeze(0).numpy() def clean_sequence(raw): lines = [] for line in raw.strip().splitlines(): line = line.strip() if not line.startswith(">"): lines.append(line.upper()) return "".join(lines) def conf_bar(label, pct, is_top): cls = "top" if is_top else "rest" badge = '▲ TOP' if is_top else "" return f"""
A, T, G, C
yang merepresentasikan kode genetik virus — biasanya hasil dari sequencing laboratorium
dalam format FASTA. Jika tidak punya data sendiri,
gunakan tombol contoh di bawah untuk mencoba sistem ini secara langsung.
Klik salah satu kelas di bawah untuk mengisi sekuens contoh secara otomatis:
', unsafe_allow_html=True ) cols = st.columns(len(SAMPLE_SEQUENCES)) for i, cls_name in enumerate(SAMPLE_SEQUENCES): with cols[i]: is_sel = st.session_state["selected_sample"] == cls_name label = f"✓ {cls_name}" if is_sel else cls_name if st.button(label, key=f"sample_{cls_name}", use_container_width=True): st.session_state["seq_textarea"] = SAMPLE_SEQUENCES[cls_name] st.session_state["selected_sample"] = cls_name st.rerun() # TEXTAREA — key="seq_textarea" sesuai session state yang diupdate tombol st.markdown('✏️ Input Sekuens DNA', unsafe_allow_html=True) raw_input = st.text_area( label="Masukkan sekuens DNA (format FASTA atau plain sequence):", height=190, placeholder=( "Contoh plain sequence:\n" "ATGATGATGATGATGATGATGATGATGAT...\n\n" "Atau format FASTA:\n" ">SampleName|Seoul\n" "ATGATGATGATGATGATGATGATGATGAT..." ), key="seq_textarea", ) seq_len = len(clean_sequence(raw_input)) ok = seq_len >= 50 counter_cls = "ok" if ok else "warn" counter_ico = "✅" if ok else "⚠️" suffix = "" if ok else " — minimal 50 bp" st.markdown( f'' f'{counter_ico} Panjang sekuens: {seq_len:,} bp{suffix}', unsafe_allow_html=True ) # PREDICT BUTTON st.markdown("", unsafe_allow_html=True) _, btn_col, _ = st.columns([1, 4, 1]) with btn_col: predict_clicked = st.button("Jalankan Klasifikasi", key="predict_seq", use_container_width=True) # HASIL if predict_clicked: sequence = clean_sequence(raw_input) if len(sequence) < 50: st.warning("Sekuens terlalu pendek. Masukkan minimal 50 base pair (bp).") return valid_chars = set("ATGCN") invalid = set(sequence) - valid_chars if invalid: st.warning( f"Karakter tidak valid: `{''.join(sorted(invalid))}`. " f"Hanya A, T, G, C, N yang diperbolehkan." ) return try: ann, class_names = load_seq_model() except FileNotFoundError: st.error(f"Model tidak ditemukan di: {MODEL_PT_PATH}") return try: tokenizer, dnabert = load_dnabert() except Exception as e: st.error(f"Gagal memuat DNABERT-2: {e}") return with st.spinner("Mengekstrak embedding dengan DNABERT-2"): emb = extract_embedding(sequence, tokenizer, dnabert) with torch.no_grad(): tensor = torch.tensor(emb, dtype=torch.float32).unsqueeze(0) probs = torch.softmax(ann(tensor), dim=1).squeeze(0).numpy() sorted_idx = np.argsort(probs)[::-1] top_class = class_names[sorted_idx[0]] top_pct = probs[sorted_idx[0]] * 100 all_bars = "".join( conf_bar(class_names[i], probs[i] * 100, rank == 0) for rank, i in enumerate(sorted_idx) ) st.markdown(f"""