BIOseq / app.py
openfree's picture
Update app.py
ba0297e verified
raw
history blame
16 kB
import io
import json
import os
from typing import Dict, List, Tuple

import requests
import streamlit as st
# Optional dependency guards.
# Each optional library is imported inside try/except ImportError so the app
# can still start without it. The matching *_AVAILABLE flag records whether the
# import succeeded, and a warning is printed to stdout when it did not.
try:
    import torch
    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False
    print("[WARNING] torch not available")
# Hugging Face loaders used by the ESM-2 and DNA embedding helpers.
try:
    from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM
    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False
    print("[WARNING] transformers not available")
# NOTE(review): load_dataset is not called anywhere in the visible code; this
# flag is only reported in the About tab's "System Status" list.
try:
    from datasets import load_dataset
    DATASETS_AVAILABLE = True
except ImportError:
    DATASETS_AVAILABLE = False
    print("[WARNING] datasets not available")
# Sentence embeddings + FAISS back the uploaded-file vector search; both must
# be present for indexing to be attempted.
try:
    from sentence_transformers import SentenceTransformer
    SENTENCE_TRANSFORMERS_AVAILABLE = True
except ImportError:
    SENTENCE_TRANSFORMERS_AVAILABLE = False
    print("[WARNING] sentence_transformers not available")
try:
    import faiss
    FAISS_AVAILABLE = True
except ImportError:
    FAISS_AVAILABLE = False
    print("[WARNING] faiss not available")
# Biopython is used to parse FASTA uploads; without it they are kept as raw text.
try:
    from Bio import SeqIO
    BIOPYTHON_AVAILABLE = True
except ImportError:
    BIOPYTHON_AVAILABLE = False
    print("[WARNING] biopython not available")
# μƒμˆ˜
APP_TITLE = "BioSeq Chat: Protein & DNA Assistant"
DISCLAIMER = "This tool is for research/education and is not a medical device. Do not use outputs for diagnosis or treatment decisions."
# --------------- Helper Functions ---------------
def get_secret(name: str, fallback: str = "") -> str:
    """Return the value of secret ``name``.

    Lookup order: Streamlit secrets (``st.secrets``) first, then the process
    environment, then ``fallback``. A missing or empty value in ``st.secrets``
    falls through to the environment.

    Args:
        name: Key to look up, e.g. ``"FIREWORKS_API_KEY"``.
        fallback: Value returned when neither source provides the key.

    Returns:
        The secret value as a string.
    """
    try:
        # Touching st.secrets can raise when no secrets file is configured;
        # that simply means "not set here", so fall through to the environment.
        if hasattr(st, "secrets") and name in st.secrets:
            value = st.secrets[name]
            if value not in (None, ""):
                # TOML secrets may be numbers/booleans; honour the -> str contract.
                return str(value)
    except Exception:
        pass
    return os.environ.get(name, fallback)
def brave_search(query: str, count: int = 5) -> List[Dict]:
    """Query the Brave Web Search API.

    Returns at most ``count`` dicts with ``title``/``url``/``snippet`` keys.
    Problems are reported in-band as a single placeholder dict rather than
    raised: a missing API key, an empty result set, or any request/parse error.
    """
    api_key = get_secret("BRAVE_API_KEY", "")
    if not api_key:
        return [{
            "title": "BRAVE_API_KEY missing",
            "url": "",
            "snippet": "Set BRAVE_API_KEY in Space secrets or sidebar",
        }]

    try:
        response = requests.get(
            "https://api.search.brave.com/res/v1/web/search",
            headers={"Accept": "application/json", "X-Subscription-Token": api_key},
            params={"q": query, "count": count},
            timeout=15,
        )
        response.raise_for_status()
        raw_items = response.json().get("web", {}).get("results", [])[:count]
        hits = [
            {
                "title": item.get("title", ""),
                "url": item.get("url", ""),
                "snippet": item.get("description", ""),
            }
            for item in raw_items
        ]
    except Exception as exc:
        return [{"title": "Error", "url": "", "snippet": str(exc)}]

    return hits or [{"title": "No results", "url": "", "snippet": ""}]
def call_llm(
    messages: List[Dict],
    temperature: float = 0.6,
    max_tokens: int = 1024,
    model: str = "accounts/fireworks/models/llama-v3p1-70b-instruct",
) -> str:
    """Send a chat-completion request to the Fireworks AI API.

    Args:
        messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.
        temperature: Sampling temperature.
        max_tokens: Upper bound on generated tokens.
        model: Fireworks model id (defaults to Llama 3.1 70B Instruct).

    Returns:
        The assistant message text. Failures are returned in-band as a string:
        a "... missing" notice when no API key is configured, or a message
        prefixed with ``[LLM Error]`` for network, HTTP, or response-format
        errors. HTTP errors include the start of the response body.
    """
    api_key = get_secret("FIREWORKS_API_KEY", "")
    if not api_key:
        return "FIREWORKS_API_KEY missing. Set it in Secrets or sidebar."
    url = "https://api.fireworks.ai/inference/v1/chat/completions"
    payload = {
        "model": model,
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": 1,
        "frequency_penalty": 0,
        "presence_penalty": 0,
    }
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    try:
        r = requests.post(url, headers=headers, json=payload, timeout=60)
    except requests.RequestException as e:
        return f"[LLM Error] {e}"
    if not r.ok:
        # The provider explains the failure (unknown/retired model id, quota,
        # bad key) in the body; raise_for_status() alone would drop it.
        return f"[LLM Error] HTTP {r.status_code}: {r.text[:500]}"
    try:
        content = r.json()["choices"][0]["message"]["content"]
    except (ValueError, KeyError, IndexError, TypeError) as e:
        return f"[LLM Error] Unexpected response format: {e!r}"
    # content may be null in the JSON; keep the declared -> str contract.
    return content or ""
def load_file_text(upload) -> str:
    """Return the text content of an uploaded file.

    The upload's bytes are decoded as UTF-8, silently dropping undecodable
    bytes. For FASTA extensions (.fa/.fasta/.faa/.fna), when Biopython is
    available, the records are re-emitted as ``>id\\nSEQUENCE`` blocks separated
    by blank lines. This joins wrapped sequence lines and keeps only the record
    id from each header. If FASTA parsing fails or finds no records, the
    decoded text is returned unchanged.

    Args:
        upload: File-like object with a ``name`` attribute and a ``read()``
            method returning bytes (e.g. a Streamlit ``UploadedFile``).

    Returns:
        The extracted text, or ``""`` if the upload could not be read.
    """
    name = upload.name.lower()
    try:
        text = upload.read().decode("utf-8", errors="ignore")
    except Exception:
        # Best-effort: an unreadable upload contributes no text.
        return ""

    if name.endswith((".fa", ".fasta", ".faa", ".fna")) and BIOPYTHON_AVAILABLE:
        try:
            # SeqIO.parse needs a text-mode handle; the upload is a binary
            # buffer, which Biopython rejects. Parse the decoded text instead.
            records = list(SeqIO.parse(io.StringIO(text), "fasta"))
        except Exception:
            records = []  # Malformed FASTA: fall back to the raw text below.
        if records:
            return "\n\n".join(f">{rec.id}\n{str(rec.seq)}" for rec in records)
    return text
def chunk_text(text: str, size: int = 1200, overlap: int = 200) -> List[str]:
    """Split ``text`` into windows of ``size`` characters overlapping by ``overlap``.

    Consecutive chunks start ``size - overlap`` characters apart. The last chunk
    ends exactly at the end of ``text`` and may be shorter than ``size``.

    Args:
        text: Text to split; ``""`` yields ``[]``.
        size: Maximum chunk length in characters (must be > 0).
        overlap: Characters shared by adjacent chunks (0 <= overlap < size).

    Returns:
        The list of chunks in document order.

    Raises:
        ValueError: If ``size``/``overlap`` would make the window fail to
            advance, which previously looped forever.
    """
    if size <= 0:
        raise ValueError(f"size must be positive, got {size}")
    if not 0 <= overlap < size:
        raise ValueError(f"overlap must satisfy 0 <= overlap < size ({size}), got {overlap}")

    step = size - overlap
    text_len = len(text)
    chunks = []
    for start in range(0, text_len, step):
        chunks.append(text[start:start + size])
        if start + size >= text_len:
            break  # This window already reached the end of the text.
    return chunks
@st.cache_resource(show_spinner=False)
def _load_sentence_model(model_name: str):
    """Load the SentenceTransformer once per process (shared across reruns/sessions)."""
    return SentenceTransformer(model_name)


def build_index(texts: List[str], model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
    """Embed ``texts`` and build an exact inner-product FAISS index over them.

    Embeddings are L2-normalised, so inner-product scores are cosine
    similarities whatever the chosen model outputs.

    Args:
        texts: Chunks to index. Index row ``i`` corresponds to ``texts[i]``.
        model_name: SentenceTransformer model id.

    Returns:
        ``(index, model)``, or ``(None, None)`` when sentence-transformers/FAISS
        are unavailable, ``texts`` is empty, or the build fails. A failure also
        shows a Streamlit warning.
    """
    if not SENTENCE_TRANSFORMERS_AVAILABLE or not FAISS_AVAILABLE:
        return None, None
    if not texts:
        return None, None
    try:
        model = _load_sentence_model(model_name)
        embeddings = model.encode(texts, show_progress_bar=False, normalize_embeddings=True)
        index = faiss.IndexFlatIP(embeddings.shape[1])
        index.add(embeddings.astype("float32"))
        return index, model
    except Exception as e:
        st.warning(f"Index build failed: {e}")
        return None, None
def search_index(query: str, index, model, texts: List[str], k: int = 4) -> List[Dict]:
    """Return the ``k`` chunks of ``texts`` most similar to ``query``.

    Args:
        query: Free-text query to embed with ``model``.
        index: FAISS index whose row ``i`` corresponds to ``texts[i]``.
        model: Embedding model exposing ``encode``.
        texts: The chunk list the index was built from.
        k: Maximum number of hits (clamped to ``len(texts)``).

    Returns:
        Hits as ``{"score": float, "text": str}`` in the order FAISS returns
        them. Returns ``[]`` when there is no index/model/corpus, ``k <= 0``, or
        the lookup fails (the failure is logged to stdout).
    """
    if index is None or model is None or not texts or k <= 0:
        return []
    try:
        # Normalise like build_index so scores are cosine similarities.
        q_emb = model.encode([query], normalize_embeddings=True)
        scores, ids = index.search(q_emb.astype("float32"), min(k, len(texts)))
    except Exception as e:
        print(f"[WARNING] search_index failed: {e}")
        return []
    # FAISS pads missing neighbours with -1; keep only valid positions.
    return [
        {"score": float(score), "text": texts[idx]}
        for idx, score in zip(ids[0], scores[0])
        if 0 <= idx < len(texts)
    ]
@st.cache_resource(show_spinner=False)
def _load_esm2(model_name: str):
    """Load an ESM-2 tokenizer/masked-LM pair in eval mode, once per process."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForMaskedLM.from_pretrained(model_name)
    model.eval()
    return tokenizer, model


def esm2_embed(
    seq: str,
    model_name: str = "facebook/esm2_t6_8M_UR50D",
    max_length: int = 1024,
) -> Dict:
    """Return a mean-pooled ESM-2 embedding for protein sequence ``seq``.

    The final-layer hidden states are averaged over all token positions
    (including the tokenizer's special tokens). Input longer than
    ``max_length`` tokens is truncated.

    Returns:
        ``{"embedding": list[float], "size": int}`` on success, otherwise
        ``{"error": message}`` (missing dependencies or any load/inference error).
    """
    if not TORCH_AVAILABLE or not TRANSFORMERS_AVAILABLE:
        return {"error": "PyTorch/Transformers not available"}
    try:
        # Cached: previously the weights were reloaded from disk on every click.
        tokenizer, model = _load_esm2(model_name)
        with torch.no_grad():
            # Cap the length so very long pasted sequences cannot blow up
            # full self-attention memory/compute.
            inputs = tokenizer(seq, return_tensors="pt", truncation=True, max_length=max_length)
            outputs = model(**inputs, output_hidden_states=True)
        hidden = outputs.hidden_states[-1].mean(dim=1).squeeze(0)
        vec = hidden.cpu().numpy()
        return {"embedding": vec.tolist(), "size": vec.shape[0]}
    except Exception as e:
        return {"error": str(e)}
# Model repos whose bundled Python code (trust_remote_code) we are willing to
# execute. The model name comes from a free-text sidebar field, so trusting
# remote code for arbitrary names would let any visitor run code on the server.
_TRUSTED_REMOTE_CODE_DNA_MODELS = frozenset({"zhihan1996/DNABERT-2-117M"})


@st.cache_resource(show_spinner=False)
def _load_dna_model(model_name: str, trust_remote_code: bool):
    """Load a DNA tokenizer/encoder pair in eval mode, once per process."""
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code)
    model = AutoModel.from_pretrained(model_name, trust_remote_code=trust_remote_code)
    model.eval()
    return tokenizer, model


def dna_embed(
    seq: str,
    model_name: str = "zhihan1996/DNABERT-2-117M",
    trust_remote_code: bool = False,
) -> Dict:
    """Return a mean-pooled DNA language-model embedding for ``seq``.

    The sequence is tokenized with truncation to 512 tokens and encoded. The
    final hidden states are then averaged over all token positions.

    Args:
        seq: Nucleotide sequence.
        model_name: Hugging Face model id.
        trust_remote_code: Force-allow repo-supplied code. Otherwise it is only
            allowed for ids in ``_TRUSTED_REMOTE_CODE_DNA_MODELS``.

    Returns:
        ``{"embedding": list[float], "size": int}`` on success, otherwise
        ``{"error": message}`` (missing dependencies or any load/inference error).
    """
    if not TORCH_AVAILABLE or not TRANSFORMERS_AVAILABLE:
        return {"error": "PyTorch/Transformers not available"}
    trust = trust_remote_code or model_name in _TRUSTED_REMOTE_CODE_DNA_MODELS
    try:
        tokenizer, model = _load_dna_model(model_name, trust)
        with torch.no_grad():
            inputs = tokenizer(seq, return_tensors="pt", truncation=True, max_length=512)
            outputs = model(**inputs)
        # DNABERT-2's remote-code model returns a plain tuple (its model card
        # uses `model(inputs)[0]`), so `.last_hidden_state` raised there.
        # Index 0 is the last hidden state for both tuples and ModelOutput.
        hidden = outputs[0].mean(dim=1).squeeze(0)
        vec = hidden.cpu().numpy()
        return {"embedding": vec.tolist(), "size": vec.shape[0]}
    except Exception as e:
        return {"error": str(e)}
def build_context(query: str, docs: List[str], index, model, use_web: bool, web_k: int) -> Tuple[str, List[Dict]]:
    """Assemble LLM context and a citation list from file hits and web results.

    Up to 4 uploaded-file chunks (first 500 chars each, tagged ``[FILE]``) come
    first. Then up to ``web_k`` web results (tagged ``[WEB]``) follow when
    ``use_web`` is set. The joined context is capped at 4000 characters.

    Returns:
        ``(context, sources)``. ``sources`` holds ``{"type": "file", "text": ...}``
        or ``{"type": "web", "title": ..., "url": ...}`` entries.
    """
    pieces: List[str] = []
    sources: List[Dict] = []

    if index is not None and model is not None and docs:
        for hit in search_index(query, index, model, docs, k=4):
            pieces.append(f"[FILE] {hit['text'][:500]}")
            sources.append({"type": "file", "text": hit["text"][:100]})

    if use_web:
        for r in brave_search(query, count=web_k):
            # brave_search reports problems (missing key, request error, no
            # results) as placeholder entries with an empty URL. They are not
            # evidence, so keep them out of the prompt and the citations.
            if not r.get("url"):
                print(f"[WARNING] web search: {r.get('title', '')}: {r.get('snippet', '')}")
                continue
            pieces.append(f"[WEB] {r['title']}\n{r['snippet']}")
            sources.append({"type": "web", "title": r["title"], "url": r["url"]})

    context = "\n\n---\n\n".join(pieces)[:4000]
    return context, sources
def answer_question(query: str, context: str) -> str:
    """Ask the LLM to answer ``query`` using the retrieved ``context``.

    Uses a fixed bioinformatics-assistant system prompt, temperature 0.4 and
    at most 1000 output tokens. Returns whatever ``call_llm`` returns,
    including its in-band error strings.
    """
    system_prompt = (
        "You are a bioinformatics assistant. Be concise and factual. "
        "Never give medical advice. Answer in the user's language."
    )
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"},
    ]
    return call_llm(conversation, temperature=0.4, max_tokens=1000)
# --------------- Streamlit UI ---------------
st.set_page_config(page_title=APP_TITLE, page_icon="🧬", layout="wide")
st.title(APP_TITLE)
st.caption(DISCLAIMER)

# Seed per-session retrieval state only once. Streamlit re-executes this script
# on every interaction, and the uploaded chunks, FAISS index and embedding model
# must survive those reruns.
for _state_key in ("docs", "index", "model"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = [] if _state_key == "docs" else None
# Sidebar
with st.sidebar:
    st.header("Configuration")
    # SECURITY: never pre-fill these inputs with the configured secrets. Widget
    # values are sent to the browser, and the password field's show/hide toggle
    # would reveal the Space owner's API keys to any visitor. An empty field
    # means "use the configured secret"; typing a key overrides it.
    fw_key = st.text_input(
        "FIREWORKS_API_KEY",
        value="",
        type="password",
        placeholder="Using configured secret" if get_secret("FIREWORKS_API_KEY", "") else "",
    )
    brave_key = st.text_input(
        "BRAVE_API_KEY",
        value="",
        type="password",
        placeholder="Using configured secret" if get_secret("BRAVE_API_KEY", "") else "",
    )
    # NOTE(review): os.environ is process-wide, so a key typed here is also used
    # by every other concurrent session of this app. Per-session storage
    # (st.session_state) would also require get_secret to read from it.
    if fw_key:
        os.environ["FIREWORKS_API_KEY"] = fw_key
    if brave_key:
        os.environ["BRAVE_API_KEY"] = brave_key
    st.divider()
    esm_model = st.text_input(
        "ESM-2 Model",
        value="facebook/esm2_t6_8M_UR50D",
    )
    dna_model = st.text_input(
        "DNA Model",
        value="zhihan1996/DNABERT-2-117M",
    )
    use_web = st.checkbox("Enable web search", value=True)
    web_results = st.slider("Web results", 1, 10, 3)
# Tabs
tab1, tab2, tab3, tab4 = st.tabs(["Chat", "Protein", "DNA", "About"])

# File upload
with st.expander("📁 Upload Files", expanded=True):
    files = st.file_uploader(
        "Upload text/FASTA files",
        type=["txt", "fa", "fasta", "faa", "fna", "csv", "json"],
        accept_multiple_files=True,
    )
    # Streamlit reruns this script on every interaction (every button click).
    # Only re-read, re-chunk and re-embed when the uploaded file set changes.
    upload_signature = tuple(
        (f.name, getattr(f, "size", None), getattr(f, "file_id", None))
        for f in files or []
    )
    if upload_signature != st.session_state.get("upload_signature"):
        st.session_state["upload_signature"] = upload_signature
        docs = []
        for f in files or []:
            try:
                text = load_file_text(f)
                if text:
                    docs.extend(chunk_text(text))
            except Exception as e:
                st.error(f"Error reading {f.name}: {e}")
        st.session_state.docs = docs
        # Reset index/model together with docs. An index built over a previous
        # document set would return row ids into the wrong chunk list.
        st.session_state.index = None
        st.session_state.model = None
        if docs and SENTENCE_TRANSFORMERS_AVAILABLE and FAISS_AVAILABLE:
            with st.spinner("Building index..."):
                index, model = build_index(docs)
            st.session_state.index = index
            st.session_state.model = model
    if st.session_state.docs:
        st.success(f"Loaded {len(st.session_state.docs)} chunks")
# Chat tab
with tab1:
    st.subheader("💬 Chat Assistant")
    question = st.text_area(
        "Ask about proteins, DNA, or bioinformatics:",
        value="What is the role of ESM-2 embeddings in protein analysis?",
        height=100,
    )
    if st.button("Get Answer", type="primary"):
        if not question.strip():
            st.warning("Please enter a question.")
        elif not get_secret("FIREWORKS_API_KEY"):
            st.error("Please set FIREWORKS_API_KEY")
        else:
            with st.spinner("Thinking..."):
                context, sources = build_context(
                    question,
                    st.session_state.docs,
                    st.session_state.index,
                    st.session_state.model,
                    use_web,
                    web_results,
                )
                answer = answer_question(question, context)
            st.markdown("### Answer")
            st.write(answer)
            if sources:
                st.markdown("### Sources")
                for s in sources:
                    if s["type"] == "web":
                        st.write(f"- 🌐 [{s['title']}]({s['url']})")
                    elif s["type"] == "file":
                        st.write(f"- 📄 File: {s['text'][:80]}...")
# Protein tab
with tab2:
    st.subheader("🧬 Protein Analysis")
    protein_seq = st.text_area(
        "Enter protein sequence:",
        value="MKTIIALSYIFCLVFA",
        height=100,
    )
    col1, col2 = st.columns(2)
    with col1:
        if st.button("Analyze Protein"):
            # Accept raw or FASTA-formatted input. Drop ">" header lines, then
            # remove all whitespace (line wraps, spaces) so it is neither
            # counted as residues nor sent to the model.
            residue_lines = [
                line for line in protein_seq.splitlines()
                if not line.lstrip().startswith(">")
            ]
            seq = "".join("".join(residue_lines).split()).upper()
            if not seq:
                st.warning("Please enter a protein sequence.")
            else:
                # Basic stats
                st.write(f"**Length:** {len(seq)}")
                st.write(f"**Unique AAs:** {len(set(seq))}")
                # ESM-2 embedding
                if TORCH_AVAILABLE and TRANSFORMERS_AVAILABLE:
                    with st.spinner("Computing embedding..."):
                        result = esm2_embed(seq, esm_model)
                    if "error" in result:
                        st.error(result["error"])
                    else:
                        st.success(f"Embedding size: {result['size']}")
                        st.json({"preview": result["embedding"][:5]})
                else:
                    st.warning("PyTorch not available for embeddings")
    with col2:
        st.info("Amino acid composition and structure prediction features coming soon")
# DNA tab
with tab3:
    st.subheader("🧬 DNA Analysis")
    dna_seq = st.text_area(
        "Enter DNA sequence:",
        value="ATGCGATCGTAGC",
        height=100,
    )
    col1, col2 = st.columns(2)
    with col1:
        if st.button("Analyze DNA"):
            # Accept raw or FASTA-formatted input. Drop ">" header lines and all
            # whitespace; otherwise line breaks inflate the length and distort
            # the GC-content denominator.
            base_lines = [
                line for line in dna_seq.splitlines()
                if not line.lstrip().startswith(">")
            ]
            seq = "".join("".join(base_lines).split()).upper()
            if not seq:
                st.warning("Please enter a DNA sequence.")
            else:
                # GC content (fraction of all characters that are G or C)
                gc = (seq.count("G") + seq.count("C")) / len(seq)
                st.write(f"**Length:** {len(seq)}")
                st.write(f"**GC Content:** {gc:.2%}")
                # DNA embedding
                if TORCH_AVAILABLE and TRANSFORMERS_AVAILABLE:
                    with st.spinner("Computing embedding..."):
                        result = dna_embed(seq, dna_model)
                    if "error" in result:
                        st.error(result["error"])
                    else:
                        st.success(f"Embedding size: {result['size']}")
                        st.json({"preview": result["embedding"][:5]})
                else:
                    st.warning("PyTorch not available for embeddings")
    with col2:
        st.info("Motif analysis and structure prediction coming soon")
# About tab
with tab4:
    st.subheader("ℹ️ About")
    # st.markdown dedents its body, so the indentation below is not rendered.
    st.markdown("""
    ### Features
    - 💬 RAG-based chat for bioinformatics questions
    - 🧬 Protein sequence analysis with ESM-2
    - 🧬 DNA sequence analysis with DNABERT-2
    - 🔍 Web search integration via Brave API
    - 📁 File upload and vector search
    ### Models
    - **Proteins:** ESM-2 (Facebook)
    - **DNA:** DNABERT-2 (Zhou et al.)
    - **LLM:** Llama 3.1 70B (via Fireworks)
    ### Disclaimer
    This tool is for research and educational purposes only.
    Not for medical diagnosis or treatment decisions.
    """)
    # Dependency check: report which optional imports succeeded at startup.
    st.divider()
    st.subheader("System Status")
    deps = {
        "PyTorch": TORCH_AVAILABLE,
        "Transformers": TRANSFORMERS_AVAILABLE,
        "Sentence Transformers": SENTENCE_TRANSFORMERS_AVAILABLE,
        "FAISS": FAISS_AVAILABLE,
        "BioPython": BIOPYTHON_AVAILABLE,
        "Datasets": DATASETS_AVAILABLE,
    }
    for name, available in deps.items():
        if available:
            st.success(f"✅ {name}")
        else:
            st.warning(f"⚠️ {name} not available")