BIOseq / app.py
openfree's picture
Update app.py
29ce347 verified
raw
history blame
18.3 kB
import os
import json
import time
from typing import List, Dict, Tuple
import streamlit as st
import requests
# Guard imports for optional dependencies
try:
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM
TORCH_AVAILABLE = True
except Exception:
TORCH_AVAILABLE = False
try:
from datasets import load_dataset
DATASETS_AVAILABLE = True
except Exception:
DATASETS_AVAILABLE = False
try:
from sentence_transformers import SentenceTransformer
SENTENCE_TRANSFORMERS_AVAILABLE = True
except Exception:
SENTENCE_TRANSFORMERS_AVAILABLE = False
try:
import faiss
FAISS_AVAILABLE = True
except Exception:
FAISS_AVAILABLE = False
try:
from Bio import SeqIO
BIOPYTHON_AVAILABLE = True
except Exception:
BIOPYTHON_AVAILABLE = False
# Constants
# Title string for the application UI.
APP_TITLE = "BioSeq Chat: Protein & DNA Assistant"
# Research-use-only notice; the two adjacent literals are concatenated into
# a single sentence pair at compile time.
DISCLAIMER = (
"This tool is for research/education and is not a medical device. "
"Do not use outputs for diagnosis or treatment decisions."
)
# --------------- Helper Functions ---------------
def get_secret(name: str, fallback: str = "") -> str:
    """Look up a configuration secret by name.

    Resolution order: ``st.secrets`` (when it can be read), then the process
    environment, then ``fallback``.

    Args:
        name: Key of the secret / environment variable.
        fallback: Value returned when the key is found nowhere.

    Returns:
        The resolved value, or ``fallback`` when the key is unknown.
    """
    env_value = os.environ.get(name, fallback)
    try:
        if hasattr(st, "secrets"):
            return st.secrets.get(name, env_value)
    except Exception:
        # Best-effort: if the secrets store cannot be read, fall back to the
        # environment. Catching Exception (rather than a bare except) keeps
        # KeyboardInterrupt/SystemExit propagating.
        pass
    return env_value
def brave_search(query: str, count: int = 5) -> List[Dict]:
    """Query the Brave Search web API and return simplified hits.

    Every returned dict has the keys ``title``, ``url`` and ``snippet``.
    A missing API key, an empty result set, or any failure during the
    request is reported as a single placeholder entry instead of raising.

    Args:
        query: Search text.
        count: Number of results requested and the cap on results returned.
    """
    api_key = get_secret("BRAVE_API_KEY", "")
    if not api_key:
        return [{"title": "BRAVE_API_KEY is missing",
                 "url": "",
                 "snippet": "Set BRAVE_API_KEY in Space secrets or sidebar to enable web search."}]

    request_headers = {
        "Accept": "application/json",
        "X-Subscription-Token": api_key,
        "Accept-Encoding": "gzip",
    }
    try:
        response = requests.get(
            "https://api.search.brave.com/res/v1/web/search",
            headers=request_headers,
            params={"q": query, "count": count, "country": "us"},
            timeout=15,
        )
        response.raise_for_status()
        raw_hits = response.json().get("web", {}).get("results", [])[:count]
        hits = [
            {
                "title": entry.get("title", ""),
                "url": entry.get("url", ""),
                "snippet": entry.get("description", ""),
            }
            for entry in raw_hits
        ]
        if hits:
            return hits
        return [{"title": "No results", "url": "", "snippet": "Query returned no results."}]
    except Exception as exc:
        return [{"title": "Search error", "url": "", "snippet": str(exc)}]
def call_fireworks(messages: List[Dict], temperature: float = 0.6, max_tokens: int = 1024) -> str:
    """Send a chat-completion request to Fireworks AI and return the reply.

    Args:
        messages: Chat messages (dicts with ``role`` and ``content``).
        temperature: Sampling temperature forwarded to the API.
        max_tokens: Completion length limit forwarded to the API.

    Returns:
        The content of the first choice, or a human-readable message when
        the API key is missing or the request fails.
    """
    token = get_secret("FIREWORKS_API_KEY", "")
    if not token:
        return "FIREWORKS_API_KEY is missing. Set it in Secrets or the sidebar."

    request_headers = {
        "Accept": "application/json",
        "Content-Type": "application/json",
        "Authorization": f"Bearer {token}",
    }
    request_body = {
        "model": "accounts/fireworks/models/llama-v3p1-70b-instruct",
        "max_tokens": max_tokens,
        "top_p": 1,
        "top_k": 40,
        "presence_penalty": 0,
        "frequency_penalty": 0,
        "temperature": temperature,
        "messages": messages,
    }
    try:
        # Serialisation stays inside the try so that an unserialisable
        # message is reported the same way as a transport error.
        response = requests.post(
            "https://api.fireworks.ai/inference/v1/chat/completions",
            headers=request_headers,
            data=json.dumps(request_body),
            timeout=60,
        )
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]
    except Exception as exc:
        return f"[Fireworks API error] {exc}"
def load_text_from_file(upload) -> str:
    """Read an uploaded file and return its contents as text.

    The payload is decoded as UTF-8 with undecodable bytes dropped. For
    FASTA extensions (``.fa``, ``.fasta``, ``.faa``, ``.fna``), when
    Biopython is available, the records are re-emitted as ``>id`` /
    sequence pairs separated by blank lines. If parsing fails or yields no
    records, the decoded text is returned unchanged.

    Args:
        upload: File-like object exposing ``name`` and ``read()``.

    Returns:
        The decoded (and, for FASTA, normalised) text.
    """
    import io  # local import keeps this block self-contained

    raw = upload.read()
    if isinstance(raw, (bytes, bytearray)):
        text = raw.decode("utf-8", errors="ignore")
    else:
        text = str(raw)

    is_fasta = upload.name.lower().endswith((".fa", ".fasta", ".faa", ".fna"))
    if is_fasta and BIOPYTHON_AVAILABLE:
        try:
            # Parse from the decoded text: the FASTA parser expects a
            # text-mode handle, and the upload handle yields bytes.
            records = list(SeqIO.parse(io.StringIO(text), "fasta"))
        except Exception:
            # Best-effort: a malformed FASTA falls back to the raw text.
            records = []
        if records:
            return "\n\n".join(f">{rec.id}\n{str(rec.seq)}" for rec in records)
    return text
def build_vector_index(texts: List[str], embedder_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
    """Embed ``texts`` and load the vectors into a FAISS inner-product index.

    Args:
        texts: Passages to embed.
        embedder_name: SentenceTransformer model identifier.

    Returns:
        ``(index, embeddings, model)`` on success. ``(None, None, None)``
        when sentence-transformers or FAISS is unavailable, or when building
        fails (in which case a Streamlit warning is shown).
    """
    nothing = (None, None, None)
    if not (SENTENCE_TRANSFORMERS_AVAILABLE and FAISS_AVAILABLE):
        return nothing

    try:
        embedder = SentenceTransformer(embedder_name)
        vectors = embedder.encode(texts, show_progress_bar=False, normalize_embeddings=True)
        width = vectors.shape[1]
        flat_index = faiss.IndexFlatIP(width)
        flat_index.add(vectors.astype("float32"))
        return flat_index, vectors, embedder
    except Exception as exc:
        st.warning(f"Failed to build index: {exc}")
        return nothing
def search_index(query: str, index, model, texts: List[str], k: int = 4):
    """Return up to ``k`` passages from ``texts`` that best match ``query``.

    Args:
        query: Search text; encoded with ``model``.
        index: Object exposing ``search(vectors, k) -> (scores, ids)``.
        model: Object exposing ``encode(list_of_str, normalize_embeddings=...)``.
        texts: Passages, positionally aligned with the ids in ``index``.
        k: Number of neighbours requested.

    Returns:
        A list of ``{"score": float, "text": str}`` dicts in the order the
        index returned them. Empty when no index/model is given or when the
        lookup fails.
    """
    if index is None or model is None:
        return []
    try:
        query_vec = model.encode([query], normalize_embeddings=True)
        scores, ids = index.search(query_vec.astype("float32"), k)
        # Ids that do not point into ``texts`` (e.g. negative filler ids)
        # are skipped rather than raising.
        return [
            {"score": float(score), "text": texts[int(idx)]}
            for idx, score in zip(ids[0], scores[0])
            if 0 <= idx < len(texts)
        ]
    except Exception:
        # Best-effort retrieval: a failed lookup yields no context instead
        # of breaking the chat flow.
        return []
def esm2_embed(seq: str, model_id: str = "facebook/esm2_t6_8M_UR50D") -> Dict:
    """Compute a mean-pooled embedding for a protein sequence.

    The tokenizer and masked-LM model named by ``model_id`` are loaded, the
    sequence is run through the model, and the last hidden layer is averaged
    over the token axis.

    Args:
        seq: Amino-acid sequence. All whitespace (including line breaks from
            pasted FASTA bodies) is removed before tokenisation.
        model_id: Hugging Face model identifier.

    Returns:
        ``{"embedding": list[float], "hidden_size": int}`` on success, or
        ``{"error": str}`` when the input is empty, the ML stack is missing,
        or loading/inference fails.
    """
    seq = "".join((seq or "").split())
    if not seq:
        # Without this guard an empty input would still "succeed" with an
        # embedding built only from the tokenizer's special tokens.
        return {"error": "Empty sequence. Please provide a protein sequence."}
    if not TORCH_AVAILABLE:
        return {"error": "Transformers/torch not available. Please wait for dependencies to install."}
    try:
        from transformers import AutoTokenizer, AutoModelForMaskedLM
        import torch

        # NOTE(security): trust_remote_code=True allows the referenced model
        # repository to execute its own Python code. Only pass model ids from
        # trusted sources; review whether this flag is needed at all.
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        model = AutoModelForMaskedLM.from_pretrained(model_id, trust_remote_code=True)
        model.eval()
        with torch.no_grad():
            toks = tokenizer(seq, return_tensors="pt")
            out = model(**toks, output_hidden_states=True)
            # Average the final layer over dim 1, then drop the batch axis.
            hidden = out.hidden_states[-1].mean(dim=1).squeeze(0)
        vec = hidden.detach().cpu().numpy()
        return {"embedding": vec.tolist(), "hidden_size": vec.shape[0]}
    except Exception as e:
        return {"error": str(e)}
def dna_embed(seq: str, model_id: str = "zhihan1996/DNABERT-2-117M") -> Dict:
    """Compute a mean-pooled embedding for a DNA sequence.

    Args:
        seq: Nucleotide sequence. All whitespace is removed before
            tokenisation; input is truncated to 4096 tokens.
        model_id: Hugging Face model identifier.

    Returns:
        ``{"embedding": list[float], "hidden_size": int}`` on success, or
        ``{"error": str}`` when the input is empty, the ML stack is missing,
        or loading/inference fails.
    """
    seq = "".join((seq or "").split())
    if not seq:
        # Avoid returning a meaningless embedding for an empty input.
        return {"error": "Empty sequence. Please provide a DNA sequence."}
    if not TORCH_AVAILABLE:
        return {"error": "Transformers/torch not available. Please wait for dependencies to install."}
    try:
        from transformers import AutoTokenizer, AutoModel
        import torch

        # NOTE(security): trust_remote_code=True allows the referenced model
        # repository to execute its own Python code. Only pass model ids from
        # trusted sources.
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
        model.eval()
        with torch.no_grad():
            toks = tokenizer(seq, return_tensors="pt", truncation=True, max_length=4096)
            out = model(**toks)
            # Positional access works whether the model returns a plain tuple
            # or an output object; attribute access (.last_hidden_state)
            # fails on tuples. NOTE(review): the default DNABERT-2 remote code
            # appears to return a tuple -- confirm against its model card.
            hidden = out[0].mean(dim=1).squeeze(0)
        vec = hidden.detach().cpu().numpy()
        return {"embedding": vec.tolist(), "hidden_size": vec.shape[0]}
    except Exception as e:
        return {"error": str(e)}
def chunk_text(text: str, chunk_size: int = 1200, overlap: int = 200) -> List[str]:
    """Split ``text`` into consecutive chunks that share ``overlap`` characters.

    Windows line endings are normalised to ``\\n`` first. Each chunk holds at
    most ``chunk_size`` characters and starts ``chunk_size - overlap``
    characters after the previous one.

    Args:
        text: Text to split.
        chunk_size: Maximum characters per chunk; must be positive.
        overlap: Characters shared between neighbouring chunks. Values of
            ``chunk_size`` or more are reduced to ``chunk_size - 1`` so the
            window always advances.

    Returns:
        The chunks in order; an empty list for empty text.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")
    # An overlap >= chunk_size would never advance the window (infinite loop).
    overlap = min(overlap, chunk_size - 1)
    step = chunk_size - overlap

    text = text.replace("\r\n", "\n")
    total = len(text)
    chunks = []
    for start in range(0, total, step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= total:
            # This chunk reached the end of the text; stop so that no
            # redundant tail-only chunks are produced.
            break
    return chunks
def build_context(user_query: str, index, index_model, docs: List[str], loaded_datasets: List, use_web: bool, web_k: int) -> Tuple[str, List[Dict]]:
    """Assemble the prompt context and a matching list of source records.

    Context is gathered, in order, from: the uploaded-file index (top 4
    hits, each cut to 800 characters), every loaded dataset that has a
    non-empty sample, and -- when ``use_web`` is set -- Brave search results.

    Returns:
        ``(context, sources)``. ``context`` is the pieces joined by a
        ``---`` separator and cut to 6000 characters. ``sources`` holds one
        dict per piece, tagged with ``type`` ``file``, ``dataset`` or ``web``.
    """
    pieces: List[str] = []
    sources: List[Dict] = []

    # Uploaded files
    if index is not None and index_model is not None and docs:
        for hit in search_index(user_query, index, index_model, docs, k=4):
            passage = hit["text"]
            pieces.append(f"[FILE] {passage[:800]}")
            sources.append({"type": "file", "text": passage[:200]})

    # Dataset samples
    for rid, sample in loaded_datasets:
        if not sample:
            continue
        pieces.append(f"[DATASET {rid}] {sample}")
        sources.append({"type": "dataset", "id": rid})

    # Web search
    if use_web:
        for result in brave_search(user_query, count=web_k):
            title = result.get("title", "")
            snippet = result.get("snippet", "")
            url = result.get("url", "")
            pieces.append(f"[WEB] {title}\n{snippet}\n{url}")
            sources.append({"type": "web", "title": title, "url": url})

    joined = "\n\n---\n\n".join(pieces)
    return joined[:6000], sources
def chat_answer(user_query: str, index, index_model, docs: List[str], loaded_datasets: List, use_web: bool, web_k: int) -> Tuple[str, List[Dict]]:
    """Answer ``user_query`` with the LLM, grounded in the gathered context.

    The context comes from ``build_context``; it is embedded in the user
    message together with the question and a language instruction, and sent
    to ``call_fireworks`` along with a fixed system prompt.

    Returns:
        ``(answer, sources)`` -- the model's reply text and the source
        records that accompany the context.
    """
    context, sources = build_context(
        user_query, index, index_model, docs, loaded_datasets, use_web, web_k
    )
    system_prompt = (
        "You are a concise, careful bioinformatics assistant for protein and DNA. "
        "Answer with factual, verifiable statements. "
        "When uncertain, say so briefly. "
        "Never give medical advice. Provide short references as plain URLs or titles if present in context. "
        "User uploads and web/dataset snippets are provided as context below."
    )
    user_prompt = f"Context:\n{context}\n\nUser question:\n{user_query}\n\nAnswer in Korean if the user used Korean; otherwise match user language."
    reply = call_fireworks(
        [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        temperature=0.4,
        max_tokens=1200,
    )
    return reply, sources
# --------------- Streamlit UI ---------------
# Page chrome. The icon is written as an ASCII escape (U+1F9EC, DNA emoji)
# so the source survives being read with the wrong text encoding.
st.set_page_config(page_title=APP_TITLE, page_icon="\U0001F9EC", layout="wide")
st.title(APP_TITLE)
st.caption(DISCLAIMER)

# Check dependencies status
if not TORCH_AVAILABLE:
    # "\u23f3" is the hourglass emoji.
    st.warning("\u23f3 PyTorch is being installed. Some features may be unavailable initially. Please refresh in a minute.")

# Initialize session state: only keys that are still missing get a default,
# so values from earlier runs of the script are preserved.
for _state_key, _state_default in (
    ("docs", []),
    ("index", None),
    ("index_model", None),
    ("loaded_datasets", []),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# Sidebar configuration
# Sidebar: API keys, model ids, web-search options and dataset ids.
with st.sidebar:
    st.header("Keys and settings")
    # The inputs are intentionally NOT pre-filled with configured secrets:
    # a widget's value is delivered to the visitor's browser, and
    # type="password" only masks how it is displayed. Leaving them blank
    # keeps get_secret() resolving the configured key as before.
    fw_key = st.text_input(
        "FIREWORKS_API_KEY",
        value="",
        type="password",
        help="Leave blank to use the key configured in secrets/environment.",
    )
    brave_key = st.text_input(
        "BRAVE_API_KEY",
        value="",
        type="password",
        help="Leave blank to use the key configured in secrets/environment.",
    )
    # NOTE(review): os.environ is process-wide, so a key entered here is
    # visible to everything running in this process -- confirm this is
    # acceptable for the deployment.
    if fw_key:
        os.environ["FIREWORKS_API_KEY"] = fw_key
    if brave_key:
        os.environ["BRAVE_API_KEY"] = brave_key

    st.markdown("### Model selections")
    esm2_id = st.text_input(
        "Protein model (ESM-2)",
        value="facebook/esm2_t6_8M_UR50D",
        help="Try larger models like facebook/esm2_t33_650M_UR50D if resources allow."
    )
    dna_id = st.text_input(
        "DNA model",
        value="zhihan1996/DNABERT-2-117M",
        help="Alternative: InstaDeepAI/nucleotide-transformer-500m-human-ref"
    )
    use_web = st.checkbox("Use Brave web search for context", value=True)
    web_k = st.slider("Web results", 1, 10, 4)

    st.markdown("### Datasets (optional)")
    dataset_ids = st.text_area(
        "Datasets to load (one per line)",
        value="",
        help="Enter Hugging Face dataset repo ids, e.g., 'genomics-benchmark/jaspar_motifs'"
    )
    st.divider()
    st.markdown("Files you upload are indexed locally and used for answers.")
# Main tabs
tabs = st.tabs(["Chat", "Protein", "DNA", "Examples", "About"])

# File upload section: uploaded files are chunked and indexed for retrieval.
with st.expander("Upload files for context (txt/csv/json/fasta/vcf)", expanded=True):
    uploads = st.file_uploader(
        "Add files",
        type=["txt", "md", "csv", "tsv", "json", "fa", "fasta", "faa", "fna", "vcf"],
        accept_multiple_files=True,
        key="file_uploader"
    )
    if uploads:
        # Streamlit re-executes the script on every interaction. Re-read and
        # re-embed only when the set of uploaded files actually changed.
        upload_sig = [(up.name, getattr(up, "size", None)) for up in uploads]
        if st.session_state.get("upload_sig") != upload_sig:
            docs = []
            for up in uploads:
                try:
                    docs.extend(chunk_text(load_text_from_file(up)))
                except Exception as e:
                    st.warning(f"Failed to read {up.name}: {e}")
            st.session_state.docs = docs
            # Drop any index built for a previous upload set so stale
            # vectors are never searched against the new chunks.
            st.session_state.index = None
            st.session_state.index_model = None
            if docs and SENTENCE_TRANSFORMERS_AVAILABLE and FAISS_AVAILABLE:
                with st.spinner("Building vector index..."):
                    index, _emb, index_model = build_vector_index(docs)
                st.session_state.index = index
                st.session_state.index_model = index_model
            st.session_state.upload_sig = upload_sig
        st.caption(f"Indexed chunks: {len(st.session_state.docs)}")
    else:
        # All uploads were removed: forget their chunks and index too.
        st.session_state.docs = []
        st.session_state.index = None
        st.session_state.index_model = None
        st.session_state.upload_sig = None
        st.caption("No files uploaded yet")
# Load datasets if specified
# Load one preview row from each dataset id listed in the sidebar.
dataset_list = [line.strip() for line in dataset_ids.splitlines() if line.strip()]
if not dataset_list:
    # The text area is empty: drop previously loaded samples so they no
    # longer leak into the chat context.
    st.session_state.loaded_datasets = []
elif not DATASETS_AVAILABLE:
    st.warning("The 'datasets' package is not available; dataset loading is disabled.")
elif dataset_list != [rid for rid, _ in st.session_state.loaded_datasets]:
    loaded = []
    for rid in dataset_list:
        with st.spinner(f"Loading dataset {rid}..."):
            try:
                ds = load_dataset(rid)
                sample = ""
                for split in ds.keys():
                    try:
                        row = ds[split][0]
                        # default=str is deliberate: the sample is only a
                        # short preview, so non-JSON values are stringified
                        # instead of discarding the whole row.
                        sample = json.dumps(row, ensure_ascii=False, default=str)[:500]
                        break
                    except Exception:
                        # Empty or unreadable split: try the next one.
                        continue
                loaded.append((rid, sample))
                st.success(f"Loaded {rid}")
            except Exception as e:
                st.error(f"Failed to load {rid}: {e}")
    st.session_state.loaded_datasets = loaded
# Chat tab
# Chat tab: free-form question answered with retrieved context.
with tabs[0]:
    st.subheader("Chat")
    # Default question (Korean): "How do ESM-2 embeddings help interpret
    # protein function?" -- restored from mis-decoded text.
    q = st.text_area("Ask a question about protein/DNA", value="ESM-2 임베딩은 단백질 기능 해석에 어떻게 도움되나요?")
    if st.button("Answer", type="primary"):
        if not q.strip():
            # Avoid spending a search and an LLM call on an empty prompt.
            st.warning("Please enter a question.")
        else:
            with st.spinner("Thinking..."):
                ans, srcs = chat_answer(
                    q,
                    st.session_state.index,
                    st.session_state.index_model,
                    st.session_state.docs,
                    st.session_state.loaded_datasets,
                    use_web,
                    web_k
                )
            st.write(ans)
            if srcs:
                st.markdown("#### Sources")
                for s in srcs:
                    if s.get("type") == "web" and s.get("url"):
                        st.markdown(f"- {s.get('title', 'web')}: {s.get('url')}")
                    elif s.get("type") == "dataset":
                        st.markdown(f"- dataset: {s.get('id')}")
                    elif s.get("type") == "file":
                        snippet = s.get("text", "")
                        st.markdown(f"- file snippet: {snippet[:120]}...")
# Protein tab
# Protein tab: on-demand embedding plus simple sequence statistics.
with tabs[1]:
    st.subheader("Protein analysis")
    seq = st.text_area("Protein sequence (amino acids only)", value="MKTIIALSYIFCLVFADYKDDDDK")
    embed_col, stats_col = st.columns(2)

    with embed_col:
        st.caption("ESM-2 embedding")
        if st.button("Run ESM-2", key="run_esm2"):
            with st.spinner("Computing ESM-2 embedding..."):
                out = esm2_embed(seq.strip(), esm2_id)
            if "error" not in out:
                st.success(f"Vector size: {out['hidden_size']}")
                # Only the first 8 components are shown as a preview.
                st.json({"embedding_preview": out["embedding"][:8]})
            else:
                st.error(out["error"])

    with stats_col:
        st.caption("Quick stats")
        # Line breaks and spaces are dropped before counting.
        s = seq.replace("\n", "").replace(" ", "").upper()
        length = len(s)
        aa_set = sorted(set(s))
        st.write(f"Length: {length}")
        st.write(f"Unique AAs: {''.join(aa_set)[:30]}")
# DNA tab
# DNA tab: on-demand embedding plus GC-content statistics.
with tabs[2]:
    st.subheader("DNA analysis")
    dseq = st.text_area("DNA sequence (ACGT only)", value="ATGCGTACGTAGCTAGCTAGCTAGGCTAGC")
    col3, col4 = st.columns(2)

    with col3:
        st.caption("DNA embedding")
        if st.button("Run DNA embed", key="run_dna"):
            with st.spinner("Computing DNA embedding..."):
                out = dna_embed(dseq.strip(), dna_id)
            if "error" in out:
                st.error(out["error"])
            else:
                st.success(f"Vector size: {out['hidden_size']}")
                # Only the first 8 components are shown as a preview.
                st.json({"embedding_preview": out["embedding"][:8]})

    with col4:
        st.caption("GC content")
        # N bases, spaces and line breaks are excluded from the length
        # used as the denominator.
        s = dseq.upper().replace("N", "").replace(" ", "").replace("\n", "")
        # Guard against division by zero for an empty sequence.
        gc = (s.count("G") + s.count("C")) / len(s) if s else 0
        st.write(f"Length: {len(s)}")
        st.write(f"GC: {gc:.3f}")
# Examples tab
with tabs[3]:
st.subheader("Examples")
st.markdown("### Example questions you can ask:")
st.markdown("- ์—…๋กœ๋“œํ•œ FASTA์—์„œ ํŠน์ • ๋‹จ๋ฐฑ์งˆ์˜ ๊ธฐ๋Šฅ ์š”์•ฝ๊ณผ ๋ณ€์ด ์˜ํ–ฅ ์งˆ๋ฌธ")
st.markdown("- DNA ์„œ์—ด์—์„œ ํ”„๋กœ๋ชจํ„ฐ ๊ฐ€๋Šฅ์„ฑ๊ณผ ์ „์‚ฌ์ธ์ž ๋ชจํ‹ฐํ”„ ๊ด€๋ จ ๊ทผ๊ฑฐ ์š”์ฒญ")
st.markdown("- Enzyme active site ๊ทผ์ ‘ ๋ณ€์ด์˜ ๋ฆฌ์Šคํฌ ํ•ด์„ (์—ฐ๊ตฌ ๊ด€์ )")
st.markdown("- ENCODE/UniProt/AlphaFold ๊ฐœ๋… ์„ค๋ช… ์š”์ฒญ")
st.markdown("- RAG ๊ธฐ๋ฐ˜์œผ๋กœ ๋ฌธ์„œ ์ธ์šฉ๊ณผ ํ•จ๊ป˜ ๊ฐ„๋žต ๋‹ต๋ณ€ ์š”์ฒญ")
# About tab
# About tab: static description of models, datasets and features.
with tabs[4]:
    st.subheader("About this Space")
    # Each entry is rendered with its own st.write call, in order; empty
    # strings act as spacers between sections.
    about_lines = (
        "**Models suggested:**",
        "- ESM-2 for proteins",
        "- DNABERT-2 or Nucleotide Transformer for DNA",
        "",
        "**Common datasets:**",
        "- UniProtKB, AlphaFoldDB, ENCODE, JASPAR, ClinVar",
        "",
        "**Features:**",
        "- Web search powered by Brave Search API",
        "- LLM powered by Fireworks AI",
        "- Vector search with FAISS",
        "",
    )
    for about_line in about_lines:
        st.write(about_line)
    st.info(DISCLAIMER)