import os import json from typing import List, Dict, Tuple import streamlit as st import requests # 선택적 의존성 가드 try: import torch TORCH_AVAILABLE = True except ImportError: TORCH_AVAILABLE = False print("[WARNING] torch not available") try: from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM TRANSFORMERS_AVAILABLE = True except ImportError: TRANSFORMERS_AVAILABLE = False print("[WARNING] transformers not available") try: from datasets import load_dataset DATASETS_AVAILABLE = True except ImportError: DATASETS_AVAILABLE = False print("[WARNING] datasets not available") try: from sentence_transformers import SentenceTransformer SENTENCE_TRANSFORMERS_AVAILABLE = True except ImportError: SENTENCE_TRANSFORMERS_AVAILABLE = False print("[WARNING] sentence_transformers not available") try: import faiss FAISS_AVAILABLE = True except ImportError: FAISS_AVAILABLE = False print("[WARNING] faiss not available") try: from Bio import SeqIO BIOPYTHON_AVAILABLE = True except ImportError: BIOPYTHON_AVAILABLE = False print("[WARNING] biopython not available") # 상수 APP_TITLE = "BioSeq Chat: Protein & DNA Assistant" DISCLAIMER = "This tool is for research/education and is not a medical device. Do not use outputs for diagnosis or treatment decisions." 
# --------------- Helper Functions ---------------

# Cache of loaded HF (tokenizer, model) pairs so repeated embed calls do not
# re-download / re-instantiate the model every time (loading dominates latency).
_MODEL_CACHE: Dict[Tuple[str, str], Tuple] = {}


def get_secret(name: str, fallback: str = "") -> str:
    """Return a secret from st.secrets, falling back to the environment.

    Args:
        name: Secret/environment variable name.
        fallback: Value returned when the name is found nowhere.
    """
    try:
        # st.secrets raises if no secrets.toml exists at all, hence the guard.
        if hasattr(st, "secrets") and name in st.secrets:
            return st.secrets[name]
    except Exception:  # BUGFIX: was a bare `except:` (catches SystemExit too)
        pass
    return os.environ.get(name, fallback)


def brave_search(query: str, count: int = 5) -> List[Dict]:
    """Query the Brave Search API.

    Returns a list of {"title", "url", "snippet"} dicts; on a missing key,
    empty result set, or request failure a single sentinel dict is returned
    so callers can always iterate.
    """
    key = get_secret("BRAVE_API_KEY", "")
    if not key:
        return [{
            "title": "BRAVE_API_KEY missing",
            "url": "",
            "snippet": "Set BRAVE_API_KEY in Space secrets or sidebar"
        }]
    url = "https://api.search.brave.com/res/v1/web/search"
    headers = {
        "Accept": "application/json",
        "X-Subscription-Token": key
    }
    params = {"q": query, "count": count}
    try:
        r = requests.get(url, headers=headers, params=params, timeout=15)
        r.raise_for_status()
        data = r.json()
        results = []
        for item in data.get("web", {}).get("results", [])[:count]:
            results.append({
                "title": item.get("title", ""),
                "url": item.get("url", ""),
                "snippet": item.get("description", "")
            })
        return results if results else [{"title": "No results", "url": "", "snippet": ""}]
    except Exception as e:
        return [{"title": "Error", "url": "", "snippet": str(e)}]


def call_llm(messages: List[Dict], temperature: float = 0.6, max_tokens: int = 1024) -> str:
    """Call the Fireworks AI chat-completions API and return the reply text.

    Errors (missing key, HTTP failure, malformed response) are returned as a
    human-readable string rather than raised, matching how the UI displays them.
    """
    api_key = get_secret("FIREWORKS_API_KEY", "")
    if not api_key:
        return "FIREWORKS_API_KEY missing. Set it in Secrets or sidebar."
    url = "https://api.fireworks.ai/inference/v1/chat/completions"
    payload = {
        "model": "accounts/fireworks/models/llama-v3p1-70b-instruct",
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": 1,
        "frequency_penalty": 0,
        "presence_penalty": 0
    }
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }
    try:
        r = requests.post(url, headers=headers, json=payload, timeout=60)
        r.raise_for_status()
        return r.json()["choices"][0]["message"]["content"]
    except Exception as e:
        return f"[LLM Error] {e}"


def load_file_text(upload) -> str:
    """Decode an uploaded file to text.

    FASTA files are re-serialized one record per ">id\\nseq" stanza when
    Biopython is available; anything else is returned as decoded UTF-8
    (undecodable bytes ignored). Returns "" on read failure.
    """
    name = upload.name.lower()
    try:
        text = upload.read().decode("utf-8", errors="ignore")
    except Exception:  # BUGFIX: was a bare `except:`
        return ""
    if name.endswith((".fa", ".fasta", ".faa", ".fna")) and BIOPYTHON_AVAILABLE:
        # BUGFIX: SeqIO.parse needs a *text* handle, but the Streamlit upload
        # object is binary — parse the already-decoded text instead.
        try:
            import io
            records = list(SeqIO.parse(io.StringIO(text), "fasta"))
            if records:
                return "\n\n".join(f">{r.id}\n{str(r.seq)}" for r in records)
        except Exception:
            pass
    return text


def chunk_text(text: str, size: int = 1200, overlap: int = 200) -> List[str]:
    """Split `text` into chunks of at most `size` chars overlapping by `overlap`."""
    if size <= 0:
        return [text] if text else []
    # BUGFIX: overlap >= size made `start` never advance (infinite loop).
    overlap = min(max(overlap, 0), size - 1)
    chunks: List[str] = []
    start = 0
    n = len(text)
    while start < n:
        end = min(start + size, n)
        chunks.append(text[start:end])
        if end >= n:
            break
        start = end - overlap
    return chunks


def build_index(texts: List[str]):
    """Build a FAISS cosine-similarity index over `texts`.

    Returns (index, embedding_model) or (None, None) when the optional
    dependencies are missing or the build fails.
    """
    if not SENTENCE_TRANSFORMERS_AVAILABLE or not FAISS_AVAILABLE:
        return None, None
    try:
        model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
        emb = model.encode(texts, show_progress_bar=False).astype("float32")
        # BUGFIX: IndexFlatIP scores by raw inner product; without L2
        # normalization the ranking is biased by vector length. Normalizing
        # makes the inner product equal to cosine similarity.
        faiss.normalize_L2(emb)
        index = faiss.IndexFlatIP(emb.shape[1])
        index.add(emb)
        return index, model
    except Exception as e:
        st.warning(f"Index build failed: {e}")
        return None, None


def search_index(query: str, index, model, texts: List[str], k: int = 4) -> List[Dict]:
    """Return up to `k` {"score", "text"} hits for `query` from the index."""
    if index is None or model is None:
        return []
    try:
        q_emb = model.encode([query]).astype("float32")
        # Match build_index: normalize so scores are cosine similarities.
        faiss.normalize_L2(q_emb)
        D, I = index.search(q_emb, k)
        results = []
        for idx, score in zip(I[0], D[0]):
            if 0 <= idx < len(texts):
                results.append({"score": float(score), "text": texts[idx]})
        return results
    except Exception:  # BUGFIX: was a bare `except:`
        return []


def _get_hf_model(model_name: str, kind: str) -> Tuple:
    """Load and cache a (tokenizer, model) pair.

    kind="esm" loads AutoModelForMaskedLM; anything else loads AutoModel with
    trust_remote_code (required by DNABERT-2's custom architecture).
    """
    key = (kind, model_name)
    if key not in _MODEL_CACHE:
        if kind == "esm":
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModelForMaskedLM.from_pretrained(model_name)
        else:
            tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
        model.eval()
        _MODEL_CACHE[key] = (tokenizer, model)
    return _MODEL_CACHE[key]


def esm2_embed(seq: str, model_name: str = "facebook/esm2_t6_8M_UR50D") -> Dict:
    """Mean-pooled last-hidden-state ESM-2 embedding of a protein sequence.

    Returns {"embedding": [...], "size": int} or {"error": str}.
    """
    if not TORCH_AVAILABLE or not TRANSFORMERS_AVAILABLE:
        return {"error": "PyTorch/Transformers not available"}
    try:
        # PERF: cached loader — previously the model was re-loaded per call.
        tokenizer, model = _get_hf_model(model_name, "esm")
        with torch.no_grad():
            inputs = tokenizer(seq, return_tensors="pt")
            outputs = model(**inputs, output_hidden_states=True)
            # NOTE(review): mean includes special tokens (CLS/EOS) — acceptable
            # for a coarse embedding preview.
            vec = outputs.hidden_states[-1].mean(dim=1).squeeze(0).numpy()
        return {"embedding": vec.tolist(), "size": vec.shape[0]}
    except Exception as e:
        return {"error": str(e)}


def dna_embed(seq: str, model_name: str = "zhihan1996/DNABERT-2-117M") -> Dict:
    """Mean-pooled DNABERT-2 embedding of a DNA sequence (truncated to 512 tokens).

    Returns {"embedding": [...], "size": int} or {"error": str}.
    """
    if not TORCH_AVAILABLE or not TRANSFORMERS_AVAILABLE:
        return {"error": "PyTorch/Transformers not available"}
    try:
        tokenizer, model = _get_hf_model(model_name, "dna")
        with torch.no_grad():
            inputs = tokenizer(seq, return_tensors="pt", truncation=True, max_length=512)
            outputs = model(**inputs)
            # ROBUSTNESS: DNABERT-2's remote code can return a plain tuple
            # instead of a ModelOutput with .last_hidden_state.
            hidden_states = outputs[0] if isinstance(outputs, tuple) else outputs.last_hidden_state
            vec = hidden_states.mean(dim=1).squeeze(0).numpy()
        return {"embedding": vec.tolist(), "size": vec.shape[0]}
    except Exception as e:
        return {"error": str(e)}


def build_context(query: str, docs: List[str], index, model, use_web: bool, web_k: int) -> Tuple[str, List[Dict]]:
    """Assemble an LLM context string (capped at 4000 chars) plus source list.

    Combines top file-index hits (tagged [FILE]) and, when enabled, Brave web
    results (tagged [WEB]).
    """
    pieces: List[str] = []
    sources: List[Dict] = []
    if index is not None and model is not None and docs:
        for h in search_index(query, index, model, docs, k=4):
            pieces.append(f"[FILE] {h['text'][:500]}")
            sources.append({"type": "file", "text": h['text'][:100]})
    if use_web:
        for r in brave_search(query, count=web_k):
            pieces.append(f"[WEB] {r['title']}\n{r['snippet']}")
            sources.append({"type": "web", "title": r['title'], "url": r['url']})
    context = "\n\n---\n\n".join(pieces)[:4000]
    return context, sources


def answer_question(query: str, context: str) -> str:
    """Generate an answer for `query` grounded in `context` via the LLM."""
    system = (
        "You are a bioinformatics assistant. Be concise and factual. "
        "Never give medical advice. Answer in the user's language."
    )
    user_msg = f"Context:\n{context}\n\nQuestion: {query}"
    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": user_msg}
    ]
    return call_llm(messages, temperature=0.4, max_tokens=1000)


# --------------- Streamlit UI ---------------
st.set_page_config(page_title=APP_TITLE, page_icon="🧬", layout="wide")
st.title(APP_TITLE)
st.caption(DISCLAIMER)

# Session state init
if "docs" not in st.session_state:
    st.session_state.docs = []
if "index" not in st.session_state:
    st.session_state.index = None
if "model" not in st.session_state:
    st.session_state.model = None

# Sidebar
with st.sidebar:
    st.header("Configuration")
    fw_key = st.text_input(
        "FIREWORKS_API_KEY",
        value=get_secret("FIREWORKS_API_KEY", ""),
        type="password"
    )
    brave_key = st.text_input(
        "BRAVE_API_KEY",
        value=get_secret("BRAVE_API_KEY", ""),
        type="password"
    )
    if fw_key:
        os.environ["FIREWORKS_API_KEY"] = fw_key
    if brave_key:
        os.environ["BRAVE_API_KEY"] = brave_key
    st.divider()
    esm_model = st.text_input(
        "ESM-2 Model",
        value="facebook/esm2_t6_8M_UR50D"
    )
    dna_model = st.text_input(
        "DNA Model",
        value="zhihan1996/DNABERT-2-117M"
    )
    use_web = st.checkbox("Enable web search", value=True)
    web_results = st.slider("Web results", 1, 10, 3)

# Tabs
tab1, tab2, tab3, tab4 = st.tabs(["Chat", "Protein", "DNA", "About"])

# File upload
with st.expander("📁 Upload Files", expanded=True):
    files = st.file_uploader(
        "Upload text/FASTA files",
        type=["txt", "fa", "fasta", "csv", "json"],
        accept_multiple_files=True
    )
    if files:
        docs = []
        for f in files:
            try:
                text = load_file_text(f)
                if text:
                    docs.extend(chunk_text(text))
            except Exception as e:
                st.error(f"Error reading {f.name}: {e}")
        if docs:
            st.success(f"Loaded {len(docs)} chunks")
            # PERF: Streamlit reruns this whole script on every interaction;
            # only (re)build the vector index when the uploads actually change.
            fingerprint = hash(tuple(docs))
            if st.session_state.get("docs_fingerprint") != fingerprint:
                st.session_state.docs = docs
                st.session_state.docs_fingerprint = fingerprint
                if SENTENCE_TRANSFORMERS_AVAILABLE and FAISS_AVAILABLE:
                    with st.spinner("Building index..."):
                        index, model = build_index(docs)
                        if index is not None:
                            st.session_state.index = index
                            st.session_state.model = model

# Chat tab
with tab1:
    st.subheader("💬 Chat Assistant")
    question = st.text_area(
        "Ask about proteins, DNA, or bioinformatics:",
        value="What is the role of ESM-2 embeddings in protein analysis?",
        height=100
    )
    if st.button("Get Answer", type="primary"):
        if not get_secret("FIREWORKS_API_KEY"):
            st.error("Please set FIREWORKS_API_KEY")
        else:
            with st.spinner("Thinking..."):
                context, sources = build_context(
                    question,
                    st.session_state.docs,
                    st.session_state.index,
                    st.session_state.model,
                    use_web,
                    web_results
                )
                answer = answer_question(question, context)
            st.markdown("### Answer")
            st.write(answer)
            if sources:
                st.markdown("### Sources")
                for s in sources:
                    if s["type"] == "web":
                        st.write(f"- 🌐 [{s['title']}]({s['url']})")
                    elif s["type"] == "file":
                        st.write(f"- 📄 File: {s['text'][:80]}...")

# Protein tab
with tab2:
    st.subheader("🧬 Protein Analysis")
    protein_seq = st.text_area(
        "Enter protein sequence:",
        value="MKTIIALSYIFCLVFA",
        height=100
    )
    col1, col2 = st.columns(2)
    with col1:
        if st.button("Analyze Protein"):
            seq = protein_seq.strip().upper()
            if not seq:
                # ROBUSTNESS: embedding an empty string just errors downstream.
                st.warning("Please enter a protein sequence")
            else:
                st.write(f"**Length:** {len(seq)}")
                st.write(f"**Unique AAs:** {len(set(seq))}")
                if TORCH_AVAILABLE and TRANSFORMERS_AVAILABLE:
                    with st.spinner("Computing embedding..."):
                        result = esm2_embed(seq, esm_model)
                    if "error" in result:
                        st.error(result["error"])
                    else:
                        st.success(f"Embedding size: {result['size']}")
                        st.json({"preview": result["embedding"][:5]})
                else:
                    st.warning("PyTorch not available for embeddings")
    with col2:
        st.info("Amino acid composition and structure prediction features coming soon")

# DNA tab
with tab3:
    st.subheader("🧬 DNA Analysis")
    dna_seq = st.text_area(
        "Enter DNA sequence:",
        value="ATGCGATCGTAGC",
        height=100
    )
    col1, col2 = st.columns(2)
    with col1:
        if st.button("Analyze DNA"):
            seq = dna_seq.strip().upper()
            if not seq:
                st.warning("Please enter a DNA sequence")
            else:
                gc = (seq.count("G") + seq.count("C")) / len(seq)
                st.write(f"**Length:** {len(seq)}")
                st.write(f"**GC Content:** {gc:.2%}")
                if TORCH_AVAILABLE and TRANSFORMERS_AVAILABLE:
                    with st.spinner("Computing embedding..."):
                        result = dna_embed(seq, dna_model)
                    if "error" in result:
                        st.error(result["error"])
                    else:
                        st.success(f"Embedding size: {result['size']}")
                        st.json({"preview": result["embedding"][:5]})
                else:
                    st.warning("PyTorch not available for embeddings")
    with col2:
        st.info("Motif analysis and structure prediction coming soon")

# About tab
with tab4:
    st.subheader("ℹ️ About")
    st.markdown("""
    ### Features
    - 💬 RAG-based chat for bioinformatics questions
    - 🧬 Protein sequence analysis with ESM-2
    - 🧬 DNA sequence analysis with DNABERT-2
    - 🔍 Web search integration via Brave API
    - 📁 File upload and vector search

    ### Models
    - **Proteins:** ESM-2 (Meta AI)
    - **DNA:** DNABERT-2 (Zhou et al.)
    - **LLM:** Llama 3.1 70B (via Fireworks)

    ### Disclaimer
    This tool is for research and educational purposes only.
    Not for medical diagnosis or treatment decisions.
    """)
    # Dependency check
    st.divider()
    st.subheader("System Status")
    deps = {
        "PyTorch": TORCH_AVAILABLE,
        "Transformers": TRANSFORMERS_AVAILABLE,
        "Sentence Transformers": SENTENCE_TRANSFORMERS_AVAILABLE,
        "FAISS": FAISS_AVAILABLE,
        "BioPython": BIOPYTHON_AVAILABLE,
        "Datasets": DATASETS_AVAILABLE
    }
    for name, available in deps.items():
        if available:
            st.success(f"✅ {name}")
        else:
            st.warning(f"⚠️ {name} not available")