kirubel1738 committed on
Commit c9d2fa0 · verified · 1 Parent(s): a7c3e81

Update src/streamlit_app.py

Files changed (1): src/streamlit_app.py (+144 -98)
src/streamlit_app.py CHANGED
@@ -1,111 +1,157 @@
- import streamlit as st
- from transformers import AutoTokenizer, AutoModelForCausalLM
- import torch
  import os
- import shutil
-
- # Define the custom cache directory for Hugging Face models
- cache_dir = "/tmp/biogpt_app_cache"
-
- # --- PROACTIVE CACHE CLEARING ---
- # Set environment variables to point Hugging Face and Streamlit to our custom cache directory.
- # This is done to prevent PermissionErrors in read-only environments.
- os.environ["STREAMLIT_CACHE_DIR"] = "/tmp/streamlit_cache"
- os.environ["HF_HOME"] = cache_dir
- os.environ["TRANSFORMERS_CACHE"] = cache_dir
- os.environ["XDG_CACHE_HOME"] = cache_dir
- os.environ["STREAMLIT_BROWSER_GATHER_USAGE_STATS"] = "false"
-
- # Clear the cache directory before attempting to download the model.
- if os.path.exists(cache_dir):
      try:
-         st.info("Clearing old cache to ensure a fresh download...")
-         shutil.rmtree(cache_dir)
-     except Exception as e:
-         st.error(f"Failed to clear old cache. Please check directory permissions. Error: {e}")
-         st.stop()

- # Ensure the new cache directory exists before model loading
- try:
-     os.makedirs(cache_dir, exist_ok=True)
- except Exception as e:
-     st.error(f"Failed to create cache directory at {cache_dir}. Error: {e}")
-     st.stop()

- st.set_page_config(page_title="BioGPT-PubMedQA Chatbot", layout="centered")
- st.title("🧬 BioGPT-PubMedQA Chatbot")
- st.write("A fine-tuned BioGPT model for biomedical Q&A.")
-
- # Detect device (CPU or GPU)
- device = "cuda" if torch.cuda.is_available() else "cpu"
- dtype = torch.float16 if torch.cuda.is_available() else torch.float32
-
- # Load model once using Streamlit's resource caching
- @st.cache_resource
- def load_model(cache_directory):
      """
-     Loads the tokenizer and model from Hugging Face Hub,
-     explicitly using the specified cache directory.
      """
-     model_name = "kirubel1738/biogpt-pubmedqa-finetuned"
-
      try:
-         tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_directory)
-         model = AutoModelForCausalLM.from_pretrained(
-             model_name,
-             torch_dtype=dtype,
-             cache_dir=cache_directory
-         ).to(device)
-         return tokenizer, model
      except Exception as e:
-         st.error("Failed to load model. Please ensure the model name is correct and it is publicly accessible.")
-         st.exception(e)
          st.stop()

- # Load the model, passing the cache directory
- try:
-     tokenizer, model = load_model(cache_dir)
- except Exception as e:
-     st.error(f"An unexpected error occurred during model loading: {e}")
-     st.stop()
-
- # Maintain chat history
- if "messages" not in st.session_state:
-     st.session_state["messages"] = []
-
- # Display chat history
- for msg in st.session_state["messages"]:
-     with st.chat_message(msg["role"]):
-         st.markdown(msg["content"])
-
- # Input box for user
- if prompt := st.chat_input("Ask me a biomedical question..."):
-     st.session_state["messages"].append({"role": "user", "content": prompt})
-
-     with st.chat_message("user"):
-         st.markdown(prompt)
-
-     formatted_prompt = f"""### Question:{prompt}### Answer:"""
-     inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
-
-     with st.spinner("Thinking..."):
-         with torch.no_grad():
-             outputs = model.generate(
-                 **inputs,
-                 max_new_tokens=200,
-                 do_sample=True,
-                 temperature=0.7,
-                 top_p=0.9,
-                 eos_token_id=tokenizer.eos_token_id,
-             )
-     decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-     if "### Answer:" in decoded:
-         answer = decoded.split("### Answer:")[-1].strip()
      else:
-         answer = decoded.strip()
-
-     st.session_state["messages"].append({"role": "assistant", "content": answer})
-
-     with st.chat_message("assistant"):
-         st.markdown(answer)
+
+ # streamlit_app.py
  import os
+ import json
+ import time
+
+ # -----------------------------
+ # IMPORTANT: set cache dirs BEFORE importing transformers/huggingface_hub
+ # -----------------------------
+ os.environ.setdefault("HF_HOME", os.environ.get("HF_HOME", "/tmp/huggingface"))
+ os.environ.setdefault("TRANSFORMERS_CACHE", os.environ.get("TRANSFORMERS_CACHE", "/tmp/huggingface/transformers"))
+ os.environ.setdefault("HF_DATASETS_CACHE", os.environ.get("HF_DATASETS_CACHE", "/tmp/huggingface/datasets"))
+ os.environ.setdefault("HUGGINGFACE_HUB_CACHE", os.environ.get("HUGGINGFACE_HUB_CACHE", "/tmp/huggingface/hub"))
+ os.environ.setdefault("XDG_CACHE_HOME", os.environ.get("XDG_CACHE_HOME", "/tmp/huggingface"))
+ os.environ.setdefault("HOME", os.environ.get("HOME", "/tmp"))
+
+ # create cache dirs (best-effort)
+ for d in [os.environ["HF_HOME"], os.environ["TRANSFORMERS_CACHE"], os.environ["HF_DATASETS_CACHE"], os.environ["HUGGINGFACE_HUB_CACHE"]]:
      try:
+         os.makedirs(d, exist_ok=True)
+         os.chmod(d, 0o777)
+     except Exception:
+         pass
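The "IMPORTANT" comment above is the crux of this rewrite: huggingface_hub and transformers resolve their cache locations when they are first imported. A quick way to confirm the redirection took effect, illustrative only and not part of the committed file (it assumes a recent huggingface_hub that exposes constants.HF_HOME):

import os
os.environ.setdefault("HF_HOME", "/tmp/huggingface")  # must run before the import below

from huggingface_hub import constants

# constants.HF_HOME is resolved from the environment at import time
print(constants.HF_HOME)  # expected: /tmp/huggingface (unless HF_HOME was already set)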
 
+ import streamlit as st
+ import requests
+
+ # Optional heavy imports will be inside local-model branch
+ LOCAL_MODE = os.environ.get("USE_LOCAL_MODEL", "0") == "1"
+
+ # default model id the user provided; keep as-is
+ DEFAULT_MODEL_ID = "kirubel1738/biogpt-pubmedqa-finetuned"

+ st.set_page_config(page_title="BioGPT (PubMedQA) demo", layout="centered")

+ st.title("BioGPT PubMedQA demo")
+ st.caption("Defaults to the Hugging Face Inference API (recommended for Spaces / CPU).")

+ st.markdown(
      """
+ **How it works**
+ - By default the app will call Hugging Face's Inference API for the model you specify (fast and avoids memory issues).
+ - If you set `USE_LOCAL_MODEL=1` in your environment, the app will attempt to load the model locally using `transformers` (only for GPUs / large-memory machines).
+ """
+ )
+
+ col1, col2 = st.columns([3, 1])
+
+ with col1:
+     model_id = st.text_input("Model repo id", value=DEFAULT_MODEL_ID, help="Hugging Face repo id (e.g. username/modelname).")
+     prompt = st.text_area("Question / prompt", height=180, placeholder="Enter a PubMed-style question or prompt...")
+ with col2:
+     max_new_tokens = st.slider("Max new tokens", 16, 1024, 128)
+     temperature = st.slider("Temperature", 0.0, 1.5, 0.0, step=0.05)
+     method = st.radio("Run method", ("Inference API (recommended)", "Local model (heavy)"), index=0)
+
+ # override radio if user set USE_LOCAL_MODEL env var
+ if LOCAL_MODE:
+     method = "Local model (heavy)"
+
+ hf_token = os.environ.get("HUGGINGFACE_HUB_TOKEN") or os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_API_TOKEN")
+
+ def call_inference_api(model_id: str, prompt: str, max_new_tokens: int, temperature: float):
+     """
+     Simple POST to the Hugging Face Inference API.
+     If you want to use the InferenceClient from huggingface_hub you can swap this.
      """
+     api_url = f"https://api-inference.huggingface.co/models/{model_id}"
+     headers = {"Authorization": f"Bearer {hf_token}"} if hf_token else {}
+     payload = {
+         "inputs": prompt,
+         "parameters": {"max_new_tokens": max_new_tokens, "temperature": temperature},
+         "options": {"wait_for_model": True}
+     }
      try:
+         r = requests.post(api_url, headers=headers, json=payload, timeout=120)
      except Exception as e:
+         return False, f"Request failed: {e}"
+     if r.status_code != 200:
+         try:
+             error = r.json()
+         except Exception:
+             error = r.text
+         return False, f"API error ({r.status_code}): {error}"
+     try:
+         resp = r.json()
+         # handle several possible response schemas
+         if isinstance(resp, dict) and "error" in resp:
+             return False, resp["error"]
+         # often it's a list of dicts with 'generated_text'
+         if isinstance(resp, list):
+             out_texts = []
+             for item in resp:
+                 if isinstance(item, dict):
+                     # common key: 'generated_text'
+                     for k in ("generated_text", "text", "content"):
+                         if k in item:
+                             out_texts.append(item[k])
+                             break
+                     else:
+                         out_texts.append(json.dumps(item))
+                 else:
+                     out_texts.append(str(item))
+             return True, "\n\n".join(out_texts)
+         # fallback
+         return True, str(resp)
+     except Exception as e:
+         return False, f"Could not parse response: {e}"
+
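The docstring above mentions swapping in InferenceClient from huggingface_hub. A minimal sketch of that variant, not part of the committed file (it reuses hf_token from above and assumes huggingface_hub is installed):

from huggingface_hub import InferenceClient

def call_inference_client(model_id: str, prompt: str, max_new_tokens: int, temperature: float):
    # InferenceClient builds the request URL and auth header for us
    client = InferenceClient(model=model_id, token=hf_token)
    try:
        text = client.text_generation(
            prompt,
            max_new_tokens=max_new_tokens,
            # some backends reject temperature == 0.0; None means greedy decoding
            temperature=temperature if temperature > 0 else None,
        )
        return True, text
    except Exception as e:
        return False, f"InferenceClient call failed: {e}"

It returns the same (ok, text) shape as call_inference_api, so the call site below would not need to change.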
+ # Local model loader (only if method chosen)
+ generator = None
+ if method.startswith("Local"):
+     st.warning("Local model mode selected — this requires transformers + torch and lots of RAM/GPU. Only use if you know the model fits your hardware.")
+     try:
+         from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+         import torch
+         device = 0 if torch.cuda.is_available() else -1
+         st.info(f"torch.cuda.is_available={torch.cuda.is_available()} -- device set to {device}")
+         with st.spinner("Loading tokenizer & model (this can take a while)..."):
+             tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=os.environ.get("TRANSFORMERS_CACHE"))
+             model = AutoModelForCausalLM.from_pretrained(model_id, cache_dir=os.environ.get("TRANSFORMERS_CACHE"), low_cpu_mem_usage=True)
+             generator = pipeline("text-generation", model=model, tokenizer=tokenizer, device=device)
+     except Exception as e:
+         st.error(f"Local model load failed: {e}")
          st.stop()

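A bad repo id in local mode only surfaces after the download attempt inside the spinner. One possible refinement, a sketch that is not in this commit (model_info comes from huggingface_hub, which transformers already depends on), is a cheap pre-flight check:

from huggingface_hub import model_info

def repo_is_reachable(repo_id: str, token=None) -> bool:
    # lightweight metadata call; raises if the repo is missing, private, or gated
    try:
        model_info(repo_id, token=token)
        return True
    except Exception:
        return False

The loader branch could call repo_is_reachable(model_id, hf_token) before the spinner and fail fast with st.error.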
+ if st.button("Generate"):
+     if not prompt or prompt.strip() == "":
+         st.error("Please enter a prompt.")
+         st.stop()
+
+     if method.startswith("Inference"):
+         if ("kirubel1738/biogpt-pubmedqa-finetuned" in model_id) and not hf_token:
+             st.info("If the model is private or rate-limited, set HUGGINGFACE_HUB_TOKEN as a secret in Spaces or as an env var locally.")
+         with st.spinner("Querying Hugging Face Inference API..."):
+             ok, out = call_inference_api(model_id, prompt, max_new_tokens, float(temperature))
+         if not ok:
+             st.error(out)
          else:
+             st.success("Done")
+             st.text_area("Model output", value=out, height=320)
+     else:
+         # local model generation
+         try:
+             with st.spinner("Running local generation..."):
+                 results = generator(prompt, max_new_tokens=max_new_tokens, do_sample=True, temperature=temperature)
+             if isinstance(results, list) and len(results) > 0 and "generated_text" in results[0]:
+                 out = results[0]["generated_text"]
+             else:
+                 out = str(results)
+             st.success("Done")
+             st.text_area("Model output", value=out, height=320)
+         except Exception as e:
+             st.error(f"Local generation failed: {e}")
+
+ st.markdown("---")
+ st.caption("If you run into permissions errors in Spaces, ensure the HF cache env vars above point to a writable directory (we already set them to /tmp/huggingface in this container).")
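For debugging outside Streamlit, the same endpoint can be exercised with a short standalone script; the question text is hypothetical, and the payload mirrors what call_inference_api sends:

import os
import requests

model_id = "kirubel1738/biogpt-pubmedqa-finetuned"
api_url = f"https://api-inference.huggingface.co/models/{model_id}"
token = os.environ.get("HUGGINGFACE_HUB_TOKEN")
headers = {"Authorization": f"Bearer {token}"} if token else {}

payload = {
    "inputs": "Is aspirin effective for primary prevention of cardiovascular disease?",
    "parameters": {"max_new_tokens": 64, "temperature": 0.7},
    "options": {"wait_for_model": True},
}

r = requests.post(api_url, headers=headers, json=payload, timeout=120)
print(r.status_code)
print(r.json())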