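"""Streamlit demo: biology Q&A with BioGPT-Large-PubMedQA plus a PEFT adapter."""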
import os

# Point all Hugging Face caches at /tmp so the app also works on hosts where
# the default cache directory is not writable (e.g. containerized deployments).
# These variables must be set before transformers/huggingface_hub are imported,
# since both libraries resolve their cache paths at import time.
os.environ.setdefault("HF_HOME", "/tmp/huggingface")
os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/huggingface/transformers")
os.environ.setdefault("HF_DATASETS_CACHE", "/tmp/huggingface/datasets")
os.environ.setdefault("HUGGINGFACE_HUB_CACHE", "/tmp/huggingface/hub")
os.environ.setdefault("XDG_CACHE_HOME", "/tmp/huggingface")

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from peft import PeftModel
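
# Note: BioGPT's tokenizer also depends on the `sacremoses` package at
# runtime, so it should be installed alongside transformers and peft.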

# Base model and the fine-tuned PEFT adapter, both on the Hugging Face Hub.
BASE_MODEL = "microsoft/BioGPT-Large-PubMedQA"
ADAPTER_MODEL = "kirubel1738/biogpt-pubmedqa-finetuned"

# st.cache_resource caches the returned pipeline so the model is loaded
# only once per process, not on every Streamlit rerun.
@st.cache_resource
def load_model():
    """Load BioGPT-Large-PubMedQA and apply the fine-tuned Biology-QA adapter."""
    # device=0 selects the first GPU; device=-1 keeps the pipeline on CPU.
    device = 0 if torch.cuda.is_available() else -1

    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)
    # Attach the adapter weights from the Hub on top of the base model.
    model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL)
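
    # Optional (assumes a LoRA-style adapter): merging the adapter into the
    # base weights can speed up inference, but afterwards `model` is a plain
    # transformers model and the adapter can no longer be detached.
    # model = model.merge_and_unload()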

    # Wrap the adapted model in a text-generation pipeline.
    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device=device,
    )
    return generator

generator = load_model()

st.set_page_config(page_title="BioGPT – PubMed Demo", layout="centered")
st.title("🧬 BioGPT – PubMed Chatbot")

st.write("Ask a biology-related question and get a generated answer.")
st.write(
    "Answers are generated by BioGPT-Large-PubMedQA fine-tuned on the "
    "cais/mmlu and allenai/sciq datasets."
)

user_input = st.text_area("Enter your biology question:", height=150)

if st.button("Get Answer"):
    if user_input.strip():
        with st.spinner("Generating answer..."):
            try:
                # Sample a continuation; temperature/top_p keep answers varied
                # but reasonably focused.
                result = generator(
                    user_input,
                    max_new_tokens=128,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                )
                output_text = result[0]["generated_text"]
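                # Note: by default the text-generation pipeline returns the
                # prompt followed by the continuation; pass
                # return_full_text=False above to show only the new text.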
                st.success("Answer:")
                st.write(output_text)
            except Exception as e:
                st.error(f"Generation failed: {e}")
    else:
        st.warning("Please enter a question.")

st.markdown("---")
st.caption(
    "Model: microsoft/BioGPT-Large-PubMedQA + adapter "
    "kirubel1738/biogpt-pubmedqa-finetuned | Uses a GPU when available, "
    "otherwise falls back to CPU."
)
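
# To run the app locally (the filename is illustrative):
#   streamlit run app.py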