|
|
|
|
|
import os |
|
|
import streamlit as st |
|
|
from transformers import pipeline |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Point every Hugging Face cache location at /tmp, which is writable on
# sandboxed hosts (Streamlit Cloud, HF Spaces). setdefault keeps any value
# the operator has already exported.
_CACHE_DEFAULTS = {
    "HF_HOME": "/tmp/huggingface",
    "TRANSFORMERS_CACHE": "/tmp/huggingface/transformers",
    "HF_DATASETS_CACHE": "/tmp/huggingface/datasets",
    "HUGGINGFACE_HUB_CACHE": "/tmp/huggingface/hub",
    "XDG_CACHE_HOME": "/tmp/huggingface",
}
for _name, _path in _CACHE_DEFAULTS.items():
    os.environ.setdefault(_name, _path)

# Hub id of the fine-tuned BioGPT checkpoint this app serves.
MODEL_ID = "kirubel1738/biogpt-pubmedqa-finetuned"
|
|
|
|
|
@st.cache_resource
def load_model():
    """Build and cache the BioGPT text-generation pipeline.

    Decorated with ``st.cache_resource`` so the model is loaded only once
    per server process and shared across sessions/reruns. ``device=-1``
    forces CPU execution.
    """
    return pipeline("text-generation", model=MODEL_ID, device=-1)
|
|
|
|
|
|
|
|
# Load (or fetch from Streamlit's resource cache) the pipeline at import time,
# so the first button click does not pay the model-loading cost.
generator = load_model()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---- Page chrome -----------------------------------------------------------
st.set_page_config(page_title="BioGPT β PubMedQA demo", layout="centered")
st.title("𧬠BioGPT β PubMedQA Demo")
st.write("Ask a biomedical question and get an answer generated by BioGPT fine-tuned on PubMedQA.")

# ---- Question input --------------------------------------------------------
user_input = st.text_area("Enter your biomedical question:", height=150)

# ---- Generation ------------------------------------------------------------
if st.button("Get Answer"):
    # Guard clause: reject blank / whitespace-only questions up front.
    if not user_input.strip():
        st.warning("Please enter a question.")
    else:
        with st.spinner("Generating answer..."):
            try:
                # Sampled decoding; the raw (unstripped) text is the prompt.
                outputs = generator(
                    user_input,
                    max_new_tokens=128,
                    do_sample=True,
                    temperature=0.7,
                )
                answer = outputs[0]["generated_text"]
                st.success("Answer:")
                st.write(answer)
            except Exception as e:
                # Surface model/runtime failures in the UI instead of
                # crashing the whole Streamlit session.
                st.error(f"Generation failed: {e}")

# ---- Footer ----------------------------------------------------------------
st.markdown("---")
st.caption("Model: kirubel1738/biogpt-pubmedqa-finetuned | Runs on CPU")
|
|
|