# biogpt-pubmedqa-chatbot / src/streamlit_app.py
# Author: kirubel1738
# Last commit message: "Update src/streamlit_app.py"
# Commit: dce22b5 (verified)
# streamlit_app.py
import os

# -----------------------------
# Ensure cache dirs are writable in Spaces
# -----------------------------
# These variables must be set BEFORE transformers / peft are imported:
# huggingface_hub resolves HF_HOME, HUGGINGFACE_HUB_CACHE and XDG_CACHE_HOME
# into module-level constants when it is first imported, so assigning them
# afterwards does not change where downloaded models are cached.
# setdefault keeps any value already supplied by the environment.
os.environ.setdefault("HF_HOME", "/tmp/huggingface")
os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/huggingface/transformers")
os.environ.setdefault("HF_DATASETS_CACHE", "/tmp/huggingface/datasets")
os.environ.setdefault("HUGGINGFACE_HUB_CACHE", "/tmp/huggingface/hub")
os.environ.setdefault("XDG_CACHE_HOME", "/tmp/huggingface")

import streamlit as st  # noqa: E402
import torch  # noqa: E402
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline  # noqa: E402
from peft import PeftModel  # noqa: E402
# Hugging Face Hub repository IDs: the base causal-LM checkpoint, and the
# PEFT adapter that load_model() applies on top of it.
BASE_MODEL: str = "microsoft/BioGPT-Large-PubMedQA"
ADAPTER_MODEL: str = "kirubel1738/biogpt-pubmedqa-finetuned"
@st.cache_resource
def load_model():
    """Build and return the text-generation pipeline for the chatbot.

    Loads the tokenizer and causal-LM weights of ``BASE_MODEL``, wraps the
    model with the PEFT adapter ``ADAPTER_MODEL``, and packages both into a
    ``transformers`` "text-generation" pipeline. Decorated with
    ``st.cache_resource`` so Streamlit reuses the loaded pipeline across
    script reruns instead of reloading the weights each time.
    """
    # Pipeline device convention: 0 selects the first CUDA GPU, -1 means CPU.
    if torch.cuda.is_available():
        target_device = 0
    else:
        target_device = -1

    bio_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    adapted_model = PeftModel.from_pretrained(
        AutoModelForCausalLM.from_pretrained(BASE_MODEL),
        ADAPTER_MODEL,
    )
    return pipeline(
        "text-generation",
        model=adapted_model,
        tokenizer=bio_tokenizer,
        device=target_device,
    )
# -----------------------------
# Streamlit UI
# -----------------------------
# Configure the page before any other Streamlit call. Older Streamlit releases
# require set_page_config to be the first Streamlit command, and the cached
# load_model() call below can render its own spinner element.
st.set_page_config(page_title="BioGPT — Pubmed Demo", layout="centered")
st.title("🧬 BioGPT — Pubmed chatbot")
st.write("Ask a biology-related question and get an answer.")
st.write("Generated by BioGPT-Large-PubMedQA fine-tuned on cais/mmlu and allenai/sciq datasets.")

# Load once (load_model is cached with st.cache_resource). Doing it after the
# header is drawn lets users see the page while the model downloads.
try:
    generator = load_model()
except Exception as e:  # top-level boundary: report instead of a raw traceback
    st.error(f"Failed to load model: {e}")
    st.stop()

user_input = st.text_area("Enter your biology question:", height=150)
if st.button("Get Answer"):
    question = user_input.strip()
    if question:
        with st.spinner("Generating answer..."):
            try:
                result = generator(
                    question,
                    max_new_tokens=128,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    # Return only the generated continuation; by default the
                    # pipeline prepends the prompt, echoing the question back.
                    return_full_text=False,
                )
                output_text = result[0]["generated_text"].strip()
            except Exception as e:
                st.error(f"Generation failed: {e}")
            else:
                if output_text:
                    st.success("Answer:")
                    st.write(output_text)
                else:
                    st.warning("The model returned an empty answer. Please try again.")
    else:
        st.warning("Please enter a question.")

st.markdown("---")
st.caption(
    f"Model: {BASE_MODEL} + adapter {ADAPTER_MODEL} | "
    "Runs on GPU when available, otherwise CPU"
)