import os
import json
import time
from typing import List, Dict, Tuple

import streamlit as st
import requests

# Guard imports for optional heavy dependencies so the app can boot
# (e.g. on a fresh Space) while these packages are still installing.
try:
    import torch
    from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM
    TORCH_AVAILABLE = True
except Exception:
    TORCH_AVAILABLE = False

try:
    from datasets import load_dataset
    DATASETS_AVAILABLE = True
except Exception:
    DATASETS_AVAILABLE = False

try:
    from sentence_transformers import SentenceTransformer
    SENTENCE_TRANSFORMERS_AVAILABLE = True
except Exception:
    SENTENCE_TRANSFORMERS_AVAILABLE = False

try:
    import faiss
    FAISS_AVAILABLE = True
except Exception:
    FAISS_AVAILABLE = False

try:
    from Bio import SeqIO
    BIOPYTHON_AVAILABLE = True
except Exception:
    BIOPYTHON_AVAILABLE = False

# Constants
APP_TITLE = "BioSeq Chat: Protein & DNA Assistant"
DISCLAIMER = (
    "This tool is for research/education and is not a medical device. "
    "Do not use outputs for diagnosis or treatment decisions."
)


# --------------- Helper Functions ---------------

def get_secret(name: str, fallback: str = "") -> str:
    """Get a secret from st.secrets, the environment, or a fallback value."""
    try:
        if hasattr(st, 'secrets'):
            return st.secrets.get(name, os.environ.get(name, fallback))
    except Exception:
        # st.secrets may raise when no secrets file is configured;
        # fall through to the environment lookup.
        pass
    return os.environ.get(name, fallback)


def brave_search(query: str, count: int = 5) -> List[Dict]:
    """Search the web using the Brave Search API.

    Returns a list of dicts with keys ``title``/``url``/``snippet``.
    A missing API key, an HTTP failure, or an empty result set are all
    reported as a single pseudo-result rather than raising, so callers
    can always render the output directly.
    """
    key = get_secret("BRAVE_API_KEY", "")
    if not key:
        return [{"title": "BRAVE_API_KEY is missing", "url": "",
                 "snippet": "Set BRAVE_API_KEY in Space secrets or sidebar to enable web search."}]
    url = "https://api.search.brave.com/res/v1/web/search"
    headers = {
        "Accept": "application/json",
        "X-Subscription-Token": key,
        "Accept-Encoding": "gzip"
    }
    params = {"q": query, "count": count, "country": "us"}
    try:
        r = requests.get(url, headers=headers, params=params, timeout=15)
        r.raise_for_status()
        data = r.json()
        results = []
        for item in data.get("web", {}).get("results", [])[:count]:
            results.append({
                "title": item.get("title", ""),
                "url": item.get("url", ""),
                "snippet": item.get("description", ""),
            })
        if results:
            return results
        return [{"title": "No results", "url": "",
                 "snippet": "Query returned no results."}]
    except Exception as e:
        return [{"title": "Search error", "url": "", "snippet": str(e)}]


def call_fireworks(messages: List[Dict], temperature: float = 0.6,
                   max_tokens: int = 1024) -> str:
    """Call the Fireworks AI chat completion API and return the reply text.

    On any failure (missing key, HTTP error, unexpected payload) a
    human-readable error string is returned instead of raising.
    """
    api_key = get_secret("FIREWORKS_API_KEY", "")
    if not api_key:
        return "FIREWORKS_API_KEY is missing. Set it in Secrets or the sidebar."
    url = "https://api.fireworks.ai/inference/v1/chat/completions"
    payload = {
        "model": "accounts/fireworks/models/llama-v3p1-70b-instruct",
        "max_tokens": max_tokens,
        "top_p": 1,
        "top_k": 40,
        "presence_penalty": 0,
        "frequency_penalty": 0,
        "temperature": temperature,
        "messages": messages
    }
    headers = {
        "Accept": "application/json",
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }
    try:
        r = requests.post(url, headers=headers, data=json.dumps(payload), timeout=60)
        r.raise_for_status()
        data = r.json()
        return data["choices"][0]["message"]["content"]
    except Exception as e:
        return f"[Fireworks API error] {e}"


def load_text_from_file(upload) -> str:
    """Load text from an uploaded file.

    FASTA-like extensions are parsed with Biopython (when available) and
    re-emitted as ``>id\\nsequence`` records; everything else is decoded
    as best-effort UTF-8.
    """
    name = upload.name.lower()
    content = upload.read()
    try:
        text = content.decode("utf-8", errors="ignore")
    except Exception:
        text = str(content)
    # FASTA file handling (best effort; fall back to plain text on failure)
    if name.endswith((".fa", ".fasta", ".faa", ".fna")) and BIOPYTHON_AVAILABLE:
        upload.seek(0)
        try:
            records = list(SeqIO.parse(upload, "fasta"))
            seqs = []
            for r in records:
                seqs.append(f">{r.id}\n{str(r.seq)}")
            return "\n\n".join(seqs)
        except Exception:
            pass
    return text


def build_vector_index(texts: List[str],
                       embedder_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
    """Build a FAISS inner-product index over sentence embeddings.

    Returns ``(index, embeddings, model)`` or ``(None, None, None)`` if
    the optional dependencies are missing or indexing fails. Embeddings
    are L2-normalized so inner product behaves as cosine similarity.
    """
    if not SENTENCE_TRANSFORMERS_AVAILABLE or not FAISS_AVAILABLE:
        return None, None, None
    try:
        model = SentenceTransformer(embedder_name)
        emb = model.encode(texts, show_progress_bar=False,
                           normalize_embeddings=True)
        dim = emb.shape[1]
        index = faiss.IndexFlatIP(dim)
        index.add(emb.astype("float32"))
        return index, emb, model
    except Exception as e:
        st.warning(f"Failed to build index: {e}")
        return None, None, None


def search_index(query: str, index, model, texts: List[str], k: int = 4):
    """Search the vector index; returns up to *k* {score, text} hits."""
    if index is None or model is None:
        return []
    try:
        q = model.encode([query], normalize_embeddings=True)
        D, I = index.search(q.astype("float32"), k)
        hits = []
        for idx, score in zip(I[0], D[0]):
            # FAISS pads missing results with -1; guard the range.
            if 0 <= idx < len(texts):
                hits.append({"score": float(score), "text": texts[idx]})
        return hits
    except Exception:
        return []


def esm2_embed(seq: str, model_id: str = "facebook/esm2_t6_8M_UR50D") -> Dict:
    """Generate a mean-pooled ESM-2 embedding for a protein sequence.

    Returns ``{"embedding": [...], "hidden_size": int}`` on success or
    ``{"error": str}`` on failure.
    """
    if not TORCH_AVAILABLE:
        return {"error": "Transformers/torch not available. Please wait for dependencies to install."}
    try:
        from transformers import AutoTokenizer, AutoModelForMaskedLM
        import torch
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        model = AutoModelForMaskedLM.from_pretrained(model_id, trust_remote_code=True)
        model.eval()
        with torch.no_grad():
            toks = tokenizer(seq, return_tensors="pt")
            out = model(**toks, output_hidden_states=True)
            # Mean-pool the last hidden layer over the sequence dimension.
            hidden = out.hidden_states[-1].mean(dim=1).squeeze(0)
        vec = hidden.detach().cpu().numpy()
        return {"embedding": vec.tolist(), "hidden_size": vec.shape[0]}
    except Exception as e:
        return {"error": str(e)}


def dna_embed(seq: str, model_id: str = "zhihan1996/DNABERT-2-117M") -> Dict:
    """Generate a mean-pooled DNA embedding.

    Returns ``{"embedding": [...], "hidden_size": int}`` on success or
    ``{"error": str}`` on failure.
    """
    if not TORCH_AVAILABLE:
        return {"error": "Transformers/torch not available. Please wait for dependencies to install."}
    try:
        from transformers import AutoTokenizer, AutoModel
        import torch
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
        model.eval()
        with torch.no_grad():
            toks = tokenizer(seq, return_tensors="pt", truncation=True, max_length=4096)
            out = model(**toks, output_hidden_states=True)
            hidden = out.last_hidden_state.mean(dim=1).squeeze(0)
        vec = hidden.detach().cpu().numpy()
        return {"embedding": vec.tolist(), "hidden_size": vec.shape[0]}
    except Exception as e:
        return {"error": str(e)}


def chunk_text(text: str, chunk_size: int = 1200, overlap: int = 200) -> List[str]:
    """Split *text* into overlapping chunks of at most *chunk_size* chars."""
    text = text.replace("\r\n", "\n")
    chunks = []
    start = 0
    while start < len(text):
        end = min(len(text), start + chunk_size)
        chunks.append(text[start:end])
        if end >= len(text):
            break
        # Step forward by chunk_size - overlap so adjacent chunks share
        # `overlap` characters of context.
        start = end - overlap
    return chunks


def build_context(user_query: str, index, index_model, docs: List[str],
                  loaded_datasets: List, use_web: bool,
                  web_k: int) -> Tuple[str, List[Dict]]:
    """Assemble RAG context from uploaded files, datasets, and web search.

    Returns ``(context, sources)`` where context is capped at 6000 chars
    and sources is a list of provenance dicts for display.
    """
    pieces = []
    sources = []

    # From uploaded files
    if index is not None and index_model is not None and docs:
        hits = search_index(user_query, index, index_model, docs, k=4)
        for h in hits:
            pieces.append(f"[FILE] {h['text'][:800]}")
            sources.append({"type": "file", "text": h["text"][:200]})

    # From datasets
    for rid, sample in loaded_datasets:
        if sample:
            pieces.append(f"[DATASET {rid}] {sample}")
            sources.append({"type": "dataset", "id": rid})

    # From web
    if use_web:
        results = brave_search(user_query, count=web_k)
        for r in results:
            snippet = r.get("snippet", "")
            url = r.get("url", "")
            title = r.get("title", "")
            pieces.append(f"[WEB] {title}\n{snippet}\n{url}")
            sources.append({"type": "web", "title": title, "url": url})

    context = "\n\n---\n\n".join(pieces)[:6000]
    return context, sources


def chat_answer(user_query: str, index, index_model, docs: List[str],
                loaded_datasets: List, use_web: bool,
                web_k: int) -> Tuple[str, List[Dict]]:
    """Generate a chat answer grounded in the assembled context."""
    context, sources = build_context(user_query, index, index_model, docs,
                                     loaded_datasets, use_web, web_k)
    system = (
        "You are a concise, careful bioinformatics assistant for protein and DNA. "
        "Answer with factual, verifiable statements. "
        "When uncertain, say so briefly. "
        "Never give medical advice. Provide short references as plain URLs or titles if present in context. "
        "User uploads and web/dataset snippets are provided as context below."
    )
    prompt = f"Context:\n{context}\n\nUser question:\n{user_query}\n\nAnswer in Korean if the user used Korean; otherwise match user language."
    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": prompt}
    ]
    answer = call_fireworks(messages, temperature=0.4, max_tokens=1200)
    return answer, sources


# --------------- Streamlit UI ---------------

st.set_page_config(page_title=APP_TITLE, page_icon="🧬", layout="wide")
st.title(APP_TITLE)
st.caption(DISCLAIMER)

# Check dependencies status
if not TORCH_AVAILABLE:
    st.warning("⏳ PyTorch is being installed. Some features may be unavailable initially. Please refresh in a minute.")

# Initialize session state
if 'docs' not in st.session_state:
    st.session_state.docs = []
if 'index' not in st.session_state:
    st.session_state.index = None
if 'index_model' not in st.session_state:
    st.session_state.index_model = None
if 'loaded_datasets' not in st.session_state:
    st.session_state.loaded_datasets = []

# Sidebar configuration
with st.sidebar:
    st.header("Keys and settings")
    fw_key = st.text_input("FIREWORKS_API_KEY",
                           value=get_secret("FIREWORKS_API_KEY", ""),
                           type="password")
    brave_key = st.text_input("BRAVE_API_KEY",
                              value=get_secret("BRAVE_API_KEY", ""),
                              type="password")
    # Export sidebar-entered keys so get_secret() picks them up.
    if fw_key:
        os.environ["FIREWORKS_API_KEY"] = fw_key
    if brave_key:
        os.environ["BRAVE_API_KEY"] = brave_key

    st.markdown("### Model selections")
    esm2_id = st.text_input(
        "Protein model (ESM-2)",
        value="facebook/esm2_t6_8M_UR50D",
        help="Try larger models like facebook/esm2_t33_650M_UR50D if resources allow."
    )
    dna_id = st.text_input(
        "DNA model",
        value="zhihan1996/DNABERT-2-117M",
        help="Alternative: InstaDeepAI/nucleotide-transformer-500m-human-ref"
    )
    use_web = st.checkbox("Use Brave web search for context", value=True)
    web_k = st.slider("Web results", 1, 10, 4)

    st.markdown("### Datasets (optional)")
    dataset_ids = st.text_area(
        "Datasets to load (one per line)",
        value="",
        help="Enter Hugging Face dataset repo ids, e.g., 'genomics-benchmark/jaspar_motifs'"
    )
    st.divider()
    st.markdown("Files you upload are indexed locally and used for answers.")

# Main tabs
tabs = st.tabs(["Chat", "Protein", "DNA", "Examples", "About"])

# File upload section
with st.expander("Upload files for context (txt/csv/json/fasta/vcf)", expanded=True):
    uploads = st.file_uploader(
        "Add files",
        type=["txt", "md", "csv", "tsv", "json", "fa", "fasta", "faa", "fna", "vcf"],
        accept_multiple_files=True,
        key="file_uploader"
    )
    if uploads:
        docs = []
        for up in uploads:
            try:
                txt = load_text_from_file(up)
                docs.extend(chunk_text(txt))
            except Exception as e:
                st.warning(f"Failed to read {up.name}: {e}")
        st.session_state.docs = docs
        st.caption(f"Indexed chunks: {len(docs)}")
        # Build index if docs available
        if docs and SENTENCE_TRANSFORMERS_AVAILABLE and FAISS_AVAILABLE:
            with st.spinner("Building vector index..."):
                index, emb, index_model = build_vector_index(docs)
                st.session_state.index = index
                st.session_state.index_model = index_model
    else:
        st.caption("No files uploaded yet")

# Load datasets if specified
if dataset_ids.strip() and DATASETS_AVAILABLE:
    dataset_list = [x.strip() for x in dataset_ids.splitlines() if x.strip()]
    # Only reload when the requested list differs from what is cached.
    if dataset_list != [d[0] for d in st.session_state.loaded_datasets]:
        st.session_state.loaded_datasets = []
        for rid in dataset_list:
            with st.spinner(f"Loading dataset {rid}..."):
                try:
                    ds = load_dataset(rid)
                    sample = ""
                    for split in ds.keys():
                        try:
                            row = ds[split][0]
                            sample = json.dumps(row, ensure_ascii=False)[:500]
                            break
                        except Exception:
                            pass
                    st.session_state.loaded_datasets.append((rid, sample))
                    st.success(f"Loaded {rid}")
                except Exception as e:
                    st.error(f"Failed to load {rid}: {e}")

# Chat tab
with tabs[0]:
    st.subheader("Chat")
    q = st.text_area("Ask a question about protein/DNA",
                     value="ESM-2 임베딩은 단백질 기능 해석에 어떻게 도움되나요?")
    if st.button("Answer", type="primary"):
        with st.spinner("Thinking..."):
            ans, srcs = chat_answer(
                q,
                st.session_state.index,
                st.session_state.index_model,
                st.session_state.docs,
                st.session_state.loaded_datasets,
                use_web,
                web_k
            )
        st.write(ans)
        if srcs:
            st.markdown("#### Sources")
            for s in srcs:
                if s.get("type") == "web" and s.get("url"):
                    st.markdown(f"- {s.get('title', 'web')}: {s.get('url')}")
                elif s.get("type") == "dataset":
                    st.markdown(f"- dataset: {s.get('id')}")
                elif s.get("type") == "file":
                    snippet = s.get("text", "")
                    st.markdown(f"- file snippet: {snippet[:120]}...")

# Protein tab
with tabs[1]:
    st.subheader("Protein analysis")
    seq = st.text_area("Protein sequence (amino acids only)",
                       value="MKTIIALSYIFCLVFADYKDDDDK")
    col1, col2 = st.columns(2)
    with col1:
        st.caption("ESM-2 embedding")
        if st.button("Run ESM-2", key="run_esm2"):
            with st.spinner("Computing ESM-2 embedding..."):
                out = esm2_embed(seq.strip(), esm2_id)
            if "error" in out:
                st.error(out["error"])
            else:
                st.success(f"Vector size: {out['hidden_size']}")
                st.json({"embedding_preview": out["embedding"][:8]})
    with col2:
        st.caption("Quick stats")
        s = seq.replace("\n", "").replace(" ", "").upper()
        length = len(s)
        aa_set = sorted(set(list(s)))
        st.write(f"Length: {length}")
        st.write(f"Unique AAs: {''.join(aa_set)[:30]}")

# DNA tab
with tabs[2]:
    st.subheader("DNA analysis")
    dseq = st.text_area("DNA sequence (ACGT only)",
                        value="ATGCGTACGTAGCTAGCTAGCTAGGCTAGC")
    col3, col4 = st.columns(2)
    with col3:
        st.caption("DNA embedding")
        if st.button("Run DNA embed", key="run_dna"):
            with st.spinner("Computing DNA embedding..."):
                out = dna_embed(dseq.strip(), dna_id)
            if "error" in out:
                st.error(out["error"])
            else:
                st.success(f"Vector size: {out['hidden_size']}")
                # Fixed: original had a stray quote after the closing brace,
                # which was a syntax error.
                st.json({"embedding_preview": out["embedding"][:8]})
    with col4:
        st.caption("GC content")
        s = dseq.upper().replace("N", "").replace(" ", "").replace("\n", "")
        if len(s) > 0:
            gc = (s.count("G") + s.count("C")) / len(s)
        else:
            gc = 0
        st.write(f"Length: {len(s)}")
        st.write(f"GC: {gc:.3f}")

# Examples tab
with tabs[3]:
    st.subheader("Examples")
    st.markdown("### Example questions you can ask:")
    st.markdown("- 업로드한 FASTA에서 특정 단백질의 기능 요약과 변이 영향 질문")
    st.markdown("- DNA 서열에서 프로모터 가능성과 전사인자 모티프 관련 근거 요청")
    st.markdown("- Enzyme active site 근접 변이의 리스크 해석 (연구 관점)")
    st.markdown("- ENCODE/UniProt/AlphaFold 개념 설명 요청")
    st.markdown("- RAG 기반으로 문서 인용과 함께 간략 답변 요청")

# About tab
with tabs[4]:
    st.subheader("About this Space")
    st.write("**Models suggested:**")
    st.write("- ESM-2 for proteins")
    st.write("- DNABERT-2 or Nucleotide Transformer for DNA")
    st.write("")
    st.write("**Common datasets:**")
    st.write("- UniProtKB, AlphaFoldDB, ENCODE, JASPAR, ClinVar")
    st.write("")
    st.write("**Features:**")
    st.write("- Web search powered by Brave Search API")
    st.write("- LLM powered by Fireworks AI")
    st.write("- Vector search with FAISS")
    st.write("")
    st.info(DISCLAIMER)