kirubel1738 committed on
Commit d17b90e · verified · 1 Parent(s): c9d2fa0

Update src/streamlit_app.py

Files changed (1)
  1. src/streamlit_app.py +44 -145
src/streamlit_app.py CHANGED
@@ -1,157 +1,56 @@
-
  # streamlit_app.py
  import os
- import json
- import time

  # -----------------------------
- # IMPORTANT: set cache dirs BEFORE importing transformers/huggingface_hub
  # -----------------------------
- os.environ.setdefault("HF_HOME", os.environ.get("HF_HOME", "/tmp/huggingface"))
- os.environ.setdefault("TRANSFORMERS_CACHE", os.environ.get("TRANSFORMERS_CACHE", "/tmp/huggingface/transformers"))
- os.environ.setdefault("HF_DATASETS_CACHE", os.environ.get("HF_DATASETS_CACHE", "/tmp/huggingface/datasets"))
- os.environ.setdefault("HUGGINGFACE_HUB_CACHE", os.environ.get("HUGGINGFACE_HUB_CACHE", "/tmp/huggingface/hub"))
- os.environ.setdefault("XDG_CACHE_HOME", os.environ.get("XDG_CACHE_HOME", "/tmp/huggingface"))
- os.environ.setdefault("HOME", os.environ.get("HOME", "/tmp"))
-
- # create cache dirs (best-effort)
- for d in [os.environ["HF_HOME"], os.environ["TRANSFORMERS_CACHE"], os.environ["HF_DATASETS_CACHE"], os.environ["HUGGINGFACE_HUB_CACHE"]]:
-     try:
-         os.makedirs(d, exist_ok=True)
-         os.chmod(d, 0o777)
-     except Exception:
-         pass
-
- import streamlit as st
- import requests
-
- # Optional heavy imports will be inside local-model branch
- LOCAL_MODE = os.environ.get("USE_LOCAL_MODEL", "0") == "1"
-
- # default model id the user provided; keep as-is
- DEFAULT_MODEL_ID = "kirubel1738/biogpt-pubmedqa-finetuned"
-
- st.set_page_config(page_title="BioGPT (PubMedQA) demo", layout="centered")

- st.title("BioGPT PubMedQA demo")
- st.caption("Defaults to the Hugging Face Inference API (recommended for Spaces / CPU).")

- st.markdown(
-     """
-     **How it works**
-     - By default the app will call Hugging Face's Inference API for the model you specify (fast and avoids memory issues).
-     - If you set `USE_LOCAL_MODEL=1` in your environment, the app will attempt to load the model locally using `transformers` (only for GPUs/large memory machines).
-     """
- )

- col1, col2 = st.columns([3,1])

- with col1:
-     model_id = st.text_input("Model repo id", value=DEFAULT_MODEL_ID, help="Hugging Face repo id (e.g. username/modelname).")
-     prompt = st.text_area("Question / prompt", height=180, placeholder="Enter a PubMed-style question or prompt...")
- with col2:
-     max_new_tokens = st.slider("Max new tokens", 16, 1024, 128)
-     temperature = st.slider("Temperature", 0.0, 1.5, 0.0, step=0.05)
-     method = st.radio("Run method", ("Inference API (recommended)", "Local model (heavy)"), index=0)
-
- # override radio if user set USE_LOCAL_MODEL env var
- if LOCAL_MODE:
-     method = "Local model (heavy)"
-
- hf_token = os.environ.get("HUGGINGFACE_HUB_TOKEN") or os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_API_TOKEN")
-
- def call_inference_api(model_id: str, prompt: str, max_new_tokens: int, temperature: float):
-     """
-     Simple POST to Hugging Face Inference API.
-     If you want to use the InferenceClient from huggingface_hub you can swap this.
-     """
-     api_url = f"https://api-inference.huggingface.co/models/{model_id}"
-     headers = {"Authorization": f"Bearer {hf_token}"} if hf_token else {}
-     payload = {
-         "inputs": prompt,
-         "parameters": {"max_new_tokens": max_new_tokens, "temperature": temperature},
-         "options": {"wait_for_model": True}
-     }
-     try:
-         r = requests.post(api_url, headers=headers, json=payload, timeout=120)
-     except Exception as e:
-         return False, f"Request failed: {e}"
-     if r.status_code != 200:
-         try:
-             error = r.json()
-         except Exception:
-             error = r.text
-         return False, f"API error ({r.status_code}): {error}"
-     try:
-         resp = r.json()
-         # handle several possible response schemas
-         if isinstance(resp, dict) and "error" in resp:
-             return False, resp["error"]
-         # often it's a list of dicts with 'generated_text'
-         if isinstance(resp, list):
-             out_texts = []
-             for item in resp:
-                 if isinstance(item, dict):
-                     # common key: 'generated_text'
-                     for k in ("generated_text", "text", "content"):
-                         if k in item:
-                             out_texts.append(item[k])
-                             break
-                     else:
-                         out_texts.append(json.dumps(item))
-                 else:
-                     out_texts.append(str(item))
-             return True, "\n\n".join(out_texts)
-         # fallback
-         return True, str(resp)
-     except Exception as e:
-         return False, f"Could not parse response: {e}"
-
- # Local model loader (only if method chosen)
- generator = None
- if method.startswith("Local"):
-     st.warning("Local model mode selected — this requires transformers + torch and lots of RAM/GPU. Only use if you know the model fits your hardware.")
-     try:
-         from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
-         import torch
-         device = 0 if torch.cuda.is_available() else -1
-         st.info(f"torch.cuda.is_available={torch.cuda.is_available()} -- device set to {device}")
-         with st.spinner("Loading tokenizer & model (this can take a while)..."):
-             tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=os.environ.get("TRANSFORMERS_CACHE"))
-             model = AutoModelForCausalLM.from_pretrained(model_id, cache_dir=os.environ.get("TRANSFORMERS_CACHE"), low_cpu_mem_usage=True)
-         generator = pipeline("text-generation", model=model, tokenizer=tokenizer, device=device)
-     except Exception as e:
-         st.error(f"Local model load failed: {e}")
-         st.stop()
-
- if st.button("Generate"):
-     if not prompt or prompt.strip() == "":
-         st.error("Please enter a prompt.")
-         st.stop()
-
-     if method.startswith("Inference"):
-         if ("kirubel1738/biogpt-pubmedqa-finetuned" in model_id) and not hf_token:
-             st.info("If the model is private or rate-limited, set HUGGINGFACE_HUB_TOKEN as a secret in Spaces or as an env var locally.")
-         with st.spinner("Querying Hugging Face Inference API..."):
-             ok, out = call_inference_api(model_id, prompt, max_new_tokens, float(temperature))
-         if not ok:
-             st.error(out)
-         else:
-             st.success("Done")
-             st.text_area("Model output", value=out, height=320)
      else:
-         # local model generation
-         try:
-             with st.spinner("Running local generation..."):
-                 results = generator(prompt, max_new_tokens=max_new_tokens, do_sample=True, temperature=temperature)
-             if isinstance(results, list) and len(results) > 0 and "generated_text" in results[0]:
-                 out = results[0]["generated_text"]
-             else:
-                 out = str(results)
-             st.success("Done")
-             st.text_area("Model output", value=out, height=320)
-         except Exception as e:
-             st.error(f"Local generation failed: {e}")

  st.markdown("---")
- st.caption("If you run into permissions errors in Spaces, ensure the HF cache env vars above point to a writable directory (we already set them to /tmp/huggingface in this container).")
 
 
  # streamlit_app.py
  import os

  # -----------------------------
+ # Ensure cache dirs are writable in Spaces.
+ # NOTE: set these BEFORE importing transformers, which reads them at import time.
  # -----------------------------
+ os.environ.setdefault("HF_HOME", "/tmp/huggingface")
+ os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/huggingface/transformers")
+ os.environ.setdefault("HF_DATASETS_CACHE", "/tmp/huggingface/datasets")
+ os.environ.setdefault("HUGGINGFACE_HUB_CACHE", "/tmp/huggingface/hub")
+ os.environ.setdefault("XDG_CACHE_HOME", "/tmp/huggingface")
+
+ import streamlit as st
+ from transformers import pipeline
+
+ # st.set_page_config must be the first Streamlit command on the page
+ st.set_page_config(page_title="BioGPT — PubMedQA demo", layout="centered")
+
+ # Hardcoded model repo
+ MODEL_ID = "kirubel1738/biogpt-pubmedqa-finetuned"
+
+ @st.cache_resource
+ def load_model():
+     """Load the BioGPT model on CPU; cached so it loads only once per process."""
+     generator = pipeline("text-generation", model=MODEL_ID, device=-1)
+     return generator
+
+ # Load once
+ generator = load_model()
+
+ # -----------------------------
+ # Streamlit UI
+ # -----------------------------
+ st.title("🧬 BioGPT PubMedQA Demo")
+
+ st.write("Ask a biomedical question and get an answer generated by BioGPT fine-tuned on PubMedQA.")
+
+ user_input = st.text_area("Enter your biomedical question:", height=150)
+
+ if st.button("Get Answer"):
+     if user_input.strip():
+         with st.spinner("Generating answer..."):
+             try:
+                 result = generator(
+                     user_input,
+                     max_new_tokens=128,
+                     do_sample=True,
+                     temperature=0.7
+                 )
+                 output_text = result[0]["generated_text"]
+                 st.success("Answer:")
+                 st.write(output_text)
+             except Exception as e:
+                 st.error(f"Generation failed: {e}")
      else:
+         st.warning("Please enter a question.")

  st.markdown("---")
+ st.caption("Model: kirubel1738/biogpt-pubmedqa-finetuned | Runs on CPU")