Kennethdotse committed on
Commit 652d9c6 · 1 Parent(s): 398f908
Files changed (2)
  1. app.py +215 -66
  2. requirements.txt +14 -0
app.py CHANGED
@@ -1,70 +1,219 @@
 
 
  import gradio as gr
- from huggingface_hub import InferenceClient
-
-
- def respond(
-     message,
-     history: list[dict[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
-     hf_token: gr.OAuthToken,
- ):
-     """
-     For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-     """
-     client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
-
-     messages = [{"role": "system", "content": system_message}]
-
-     messages.extend(history)
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         choices = message.choices
-         token = ""
-         if len(choices) and choices[0].delta.content:
-             token = choices[0].delta.content
-
-         response += token
-         yield response
-
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- chatbot = gr.ChatInterface(
-     respond,
-     type="messages",
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
- )
-
- with gr.Blocks() as demo:
-     with gr.Sidebar():
-         gr.LoginButton()
-     chatbot.render()


  if __name__ == "__main__":
-     demo.launch()
 
+ import os
+ import sys
  import gradio as gr
+ import torch
+ from transformers import pipeline, BitsAndBytesConfig
+ from datasets import load_dataset
+ import pandas as pd
+ from PIL import Image
+ from typing import Optional
+ from pathlib import Path
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from langchain_community.vectorstores import Chroma
+ from langchain.document_loaders import DataFrameLoader, PyPDFLoader, CSVLoader
+ from langchain.text_splitter import CharacterTextSplitter
+ from huggingface_hub import HfApi


+ # ---------- Configuration ----------
+ MODEL_VARIANT = os.environ.get("MODEL_VARIANT", "4b-it")
+ MODEL_ID = f"google/medgemma-{MODEL_VARIANT}"
+ USE_QUANTIZATION = True
+ LOCAL_DOCS_PATH = Path("./medical/hb_db")
+ CHROMA_PERSIST_DIR = "./chroma_db"
+
+ _pipe = None
+ _rag_vectorstore = None
+ _embeddings = None
+
+ HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
+ if not HF_TOKEN:
+     print("Error: no Hugging Face token found. Set HF_TOKEN or HUGGINGFACEHUB_API_TOKEN as an environment variable or Space secret.")
+     sys.exit(1)
+ else:
+     try:
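+         # whoami() performs a real Hub API call, so an invalid token fails fast here, before any model download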
+         HfApi().whoami(token=HF_TOKEN)
+         print("Hugging Face token OK")
+     except Exception as e:
+         print("Invalid Hugging Face token:", e)
+         sys.exit(1)
+
+ # ---------- Lazy initialization helpers ----------
+ def _init_pipeline():
+     global _pipe
+     if _pipe is not None:
+         return _pipe
+
+     # Model kwargs
+     model_kwargs = dict(
+         torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
+         device_map="auto",
+     )
+
+     if USE_QUANTIZATION:
+         try:
+             model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
+         except Exception:
+             # bitsandbytes may not be available on CPU-only setups; ignore and fall back
+             pass
+
+     # Choose pipeline task type depending on variant
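+     # (the default "4b-it" variant contains "it", so the multimodal image-text-to-text task is selected)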
+     task = "image-text-to-text" if "image" in MODEL_VARIANT or "it" in MODEL_VARIANT else "text-generation"
+
+     print(f"Initializing pipeline: {MODEL_ID} task={task}")
+     _pipe = pipeline(
+         task,
+         model=MODEL_ID,
+         device_map=model_kwargs.get("device_map"),
+         torch_dtype=model_kwargs.get("torch_dtype"),
+         use_auth_token=HF_TOKEN,
+         **({} if "quantization_config" not in model_kwargs else {"quantization_config": model_kwargs["quantization_config"]}),
+     )
+     try:
+         _pipe.model.generation_config.do_sample = False
+     except Exception:
+         pass
+
+     return _pipe
+
+
+ def _init_rag():
+     """Builds or loads a Chroma vectorstore from local files. This runs lazily on first request."""
+     global _rag_vectorstore, _embeddings
+     if _rag_vectorstore is not None:
+         return _rag_vectorstore
+
+     docs = []
+
+     # 1) Load a Hugging Face dataset (if available); convert to a DataFrame
+     try:
+         ds = load_dataset("knowrohit07/know_medical_dialogue_v2")
+         df = pd.DataFrame(ds["train"])
+         if "instruction" in df.columns and "output" in df.columns:
+             df["full_dialogue"] = df["instruction"].astype(str) + " \n\n" + df["output"].astype(str)
+             loader = DataFrameLoader(df, page_content_column="full_dialogue")
+             docs += loader.load()
+     except Exception as e:
+         print("Warning: could not load HF dataset:", e)
+
+     # 2) Load local CSV if present
+     csv_path = LOCAL_DOCS_PATH / "Final_Dataset.csv"
+     if csv_path.exists():
+         try:
+             csv_loader = CSVLoader(str(csv_path))
+             docs += csv_loader.load()
+         except Exception as e:
+             print("Warning loading CSV:", e)
+
+     # 3) Load PDFs found in the directory
+     if LOCAL_DOCS_PATH.exists() and LOCAL_DOCS_PATH.is_dir():
+         for pdf_file in LOCAL_DOCS_PATH.glob("*.pdf"):
+             try:
+                 pdf_loader = PyPDFLoader(str(pdf_file))
+                 docs += pdf_loader.load()
+             except Exception as e:
+                 print(f"Warning loading PDF {pdf_file}: {e}")
+
+     # 4) If still no docs, create a placeholder document
+     if len(docs) == 0:
+         from langchain.schema import Document
+         docs = [Document(page_content="No local documents found. Upload PDFs/CSV into ./medical/hb_db or commit them to the Space repo.")]
+
+     # 5) Split into chunks
+     splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
+     chunks = splitter.split_documents(docs)
+
+     # 6) Embeddings and Chroma vectorstore
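+     # (the store is rebuilt and re-embedded on each cold start; persist_directory only writes it to disk)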
+     try:
+         _embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
+         _rag_vectorstore = Chroma.from_documents(chunks, _embeddings, persist_directory=CHROMA_PERSIST_DIR)
+         try:
+             _rag_vectorstore.persist()
+         except Exception:
+             pass
+     except Exception as e:
+         print("Error initializing vectorstore:", e)
+         _rag_vectorstore = None
+
+     return _rag_vectorstore
+
+
+ # ---------- Main RAG + generation function ----------
+
+ def generate_medgemma_rag_response(query: str, image: Optional[Image.Image] = None) -> str:
+     """Generate an answer using RAG + the MedGemma model. This function lazily initializes heavy resources."""
+     # Ensure the RAG vectorstore is initialized
+     vs = _init_rag()
+
+     # Retrieve relevant docs if the vectorstore exists
+     context = ""
+     if vs is not None:
+         try:
+             retrieved = vs.similarity_search(query, k=4)
+             context = "\n\n".join([d.page_content for d in retrieved])
+         except Exception as e:
+             print("Warning during similarity search:", e)
+
+     # Construct prompt
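+     # (retrieved chunks are inlined ahead of the user question: simple prompt stuffing)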
+     rag_prompt = f"You are a respectful, medical AI assistant. Use the provided context and your knowledge to answer and be clear when uncertain.\n\nContext:\n{context}\n\nUser Question: {query}\n\nAnswer:\n"
+
+     # Initialize pipeline lazily
+     pipe = _init_pipeline()
+
+     # Build input for the pipeline. The exact expected format can vary by pipeline task.
+     if image is not None:
+         # Provide an image + text prompt; the pipeline expects inputs in a tuple/list depending on the model
+         input_for_pipe = {"image": image, "text": rag_prompt}
+         try:
+             out = pipe(input_for_pipe, max_new_tokens=512)
+         except Exception:
+             # Fall back to a plain text prompt if the image pipeline fails
+             out = pipe(rag_prompt, max_new_tokens=512)
+     else:
+         out = pipe(rag_prompt, max_new_tokens=512)
+
+     # Normalize output; many pipelines return a list of dicts
+     try:
+         if isinstance(out, list) and len(out) > 0:
+             # Prefer a sensible key if present
+             if isinstance(out[0], dict):
+                 text = out[0].get("generated_text") or out[0].get("text") or str(out[0])
+             else:
+                 text = str(out[0])
+         else:
+             text = str(out)
+     except Exception:
+         text = str(out)
+
+     return text
+
+
+ # ...existing code...
+ with gr.Blocks() as iface:
+     chatbot = gr.Chatbot(label="Ayaresa chat")
+     with gr.Row():
+         with gr.Column(scale=3):
+             txt = gr.Textbox(label="Enter a prompt", placeholder="Type your question here...", lines=2)
+         with gr.Column(scale=1):
+             img = gr.Image(type="pil", label="Image (optional)")
+     with gr.Row():
+         send = gr.Button("Send")
+         clear = gr.Button("Clear")
+
+     # Keep conversation state explicitly
+     state = gr.State([])
+
+     def submit_fn(message, image, history):
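+         # Returns (updated chat history, cleared textbox value, updated state), matching outputs=[chatbot, txt, state] below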
+         history = history or []
+         if (not message or message.strip() == "") and image is None:
+             return history, "", history
+         resp = generate_medgemma_rag_response(message or "", image)
+         history.append((message or "", resp))
+         return history, "", history
+
+     send.click(submit_fn, inputs=[txt, img, state], outputs=[chatbot, txt, state])
+     txt.submit(submit_fn, inputs=[txt, img, state], outputs=[chatbot, txt, state])
+     clear.click(lambda: ([], "", []), inputs=None, outputs=[chatbot, txt, state])
+
  if __name__ == "__main__":
+     iface.launch()
requirements.txt ADDED
@@ -0,0 +1,14 @@
+ gradio==4.10.0
+ torch
+ transformers
+ datasets
+ pandas
+ Pillow
+ langchain
+ langchain-community
+ chromadb
+ sentence-transformers
+ pypdf
+ bitsandbytes
+ accelerate
+ huggingface-hub