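"""Ayaresa: a Gradio Space that answers medical questions with MedGemma,
grounding responses in local documents via a Chroma-backed RAG pipeline."""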
import os
import sys
import gradio as gr
import torch
from transformers import pipeline, BitsAndBytesConfig
from datasets import load_dataset
import pandas as pd
from PIL import Image
from typing import Optional
from pathlib import Path
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.document_loaders.dataframe import DataFrameLoader
from langchain_text_splitters import CharacterTextSplitter
# ---------- Configuration ----------
MODEL_VARIANT = os.environ.get("MODEL_VARIANT", "4b-it")
MODEL_ID = f"google/medgemma-{MODEL_VARIANT}"
USE_QUANTIZATION = True
LOCAL_DOCS_PATH = Path("./medical/hb_db")
CHROMA_PERSIST_DIR = "./chroma_db"
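# Example (illustrative invocation): switch to the text-only variant with
#   MODEL_VARIANT=27b-text-it python app.py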
_pipe = None
_rag_vectorstore = None
_embeddings = None
# ---------- Lazy initialization helpers ----------
def _init_pipeline():
global _pipe
if _pipe is not None:
return _pipe
# Model kwargs
model_kwargs = dict(
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
device_map="auto",
)
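    # 4-bit (bitsandbytes) quantization roughly quarters GPU memory versus
    # fp16/bf16 weights, at a small quality cost.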
if USE_QUANTIZATION:
try:
model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
except Exception:
# bitsandbytes may not be available on CPU-only setups; ignore and fall back
pass
    # Choose the pipeline task for the variant: MedGemma's text-only variants
    # (e.g. 27b-text-it) carry "text" in the name; the rest (e.g. 4b-it) are
    # multimodal and use the image-text-to-text task.
    task = "text-generation" if "text" in MODEL_VARIANT else "image-text-to-text"
print(f"Initializing pipeline: {MODEL_ID} task={task}")
    _pipe = pipeline(
        task,
        model=MODEL_ID,
        # Pass device_map/dtype/quantization through model_kwargs so they all
        # reach from_pretrained, whichever pipeline task was selected.
        model_kwargs=model_kwargs,
    )
try:
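        # Greedy decoding keeps answers deterministic, which suits medical Q&A.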
_pipe.model.generation_config.do_sample = False
except Exception:
pass
return _pipe
def _init_rag():
"""Builds or loads a Chroma vectorstore from local files. This runs lazily on first request."""
global _rag_vectorstore, _embeddings
if _rag_vectorstore is not None:
return _rag_vectorstore
docs = []
# 1) Load a Hugging Face dataset (if available) — convert to a DataFrame
try:
ds = load_dataset("knowrohit07/know_medical_dialogue_v2")
        df = ds["train"].to_pandas()  # Dataset -> pandas DataFrame
if "instruction" in df.columns and "output" in df.columns:
df["full_dialogue"] = df["instruction"].astype(str) + " \n\n" + df["output"].astype(str)
loader = DataFrameLoader(df, page_content_column="full_dialogue")
docs += loader.load()
except Exception as e:
print("Warning: could not load HF dataset:", e)
# 2) Load local CSV if present
csv_path = LOCAL_DOCS_PATH / "Final_Dataset.csv"
if csv_path.exists():
try:
csv_loader = CSVLoader(str(csv_path))
docs += csv_loader.load()
except Exception as e:
print("Warning loading CSV:", e)
# 3) Load PDFs found in the directory
if LOCAL_DOCS_PATH.exists() and LOCAL_DOCS_PATH.is_dir():
for pdf_file in LOCAL_DOCS_PATH.glob("*.pdf"):
try:
pdf_loader = PyPDFLoader(str(pdf_file))
docs += pdf_loader.load()
except Exception as e:
print(f"Warning loading PDF {pdf_file}: {e}")
# 4) If still no docs, create a placeholder document
if len(docs) == 0:
        from langchain_core.documents import Document
docs = [Document(page_content="No local documents found. Upload PDFs/CSV into ./medical/hb_db or commit them to the Space repo.")]
# 5) Split into chunks
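    # CharacterTextSplitter splits on "\n\n" by default; ~1000-character chunks
    # with 200 characters of overlap preserve context across chunk boundaries.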
splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
chunks = splitter.split_documents(docs)
# 6) Embeddings and Chroma vectorstore
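    # all-MiniLM-L6-v2 is a small 384-dimensional sentence-transformers model,
    # fast enough to embed on CPU.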
try:
_embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
_rag_vectorstore = Chroma.from_documents(chunks, _embeddings, persist_directory=CHROMA_PERSIST_DIR)
try:
_rag_vectorstore.persist()
except Exception:
pass
except Exception as e:
print("Error initializing vectorstore:", e)
_rag_vectorstore = None
return _rag_vectorstore
# ---------- Main RAG + generation function ----------
def generate_medgemma_rag_response(query: str, image: Optional[Image.Image] = None) -> str:
"""Generate an answer using RAG + MedGemma model. This function will lazily initialize heavy resources."""
# Ensure rag is initialized
vs = _init_rag()
# Retrieve relevant docs if vectorstore exists
context = ""
if vs is not None:
try:
retrieved = vs.similarity_search(query, k=4)
context = "\n\n".join([d.page_content for d in retrieved])
except Exception as e:
print("Warning during similarity search:", e)
# Construct prompt
rag_prompt = f"You are a respectful, medical AI assistant. Use the provided context and your knowledge to answer and be clear when uncertain.\n\nContext:\n{context}\n\nUser Question: {query}\n\nAnswer:\n"
# Initialize pipeline lazily
pipe = _init_pipeline()
# Build input for the pipeline. The exact expected format can vary by pipeline task.
    if image is not None:
        # Multimodal pipelines expect the chat-messages format: typed image
        # and text parts inside a single user turn.
        messages = [{"role": "user", "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": rag_prompt},
        ]}]
        try:
            out = pipe(text=messages, max_new_tokens=512)
        except Exception:
            # fallback to a plain text prompt if the multimodal call fails
            out = pipe(rag_prompt, max_new_tokens=512)
else:
out = pipe(rag_prompt, max_new_tokens=512)
    # Normalize the output: most pipelines return a list of dicts.
    try:
        first = out[0] if isinstance(out, list) and out else out
        if isinstance(first, dict):
            gen = first.get("generated_text") or first.get("text") or str(first)
            # Chat-style pipelines return the full message list; the last
            # turn holds the assistant's reply.
            if isinstance(gen, list) and gen and isinstance(gen[-1], dict):
                gen = gen[-1].get("content", str(gen[-1]))
            text = str(gen)
        else:
            text = str(first)
    except Exception:
        text = str(out)
return text
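# Quick smoke test (illustrative only; downloads the model on first call):
#   print(generate_medgemma_rag_response("What are common symptoms of malaria?"))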
# ---------- Gradio UI ----------
with gr.Blocks() as iface:
chatbot = gr.Chatbot(label="Ayaresa chat")
with gr.Row():
with gr.Column(scale=3):
txt = gr.Textbox(label="Enter a prompt", placeholder="Type your question here...", lines=2)
with gr.Column(scale=1):
img = gr.Image(type="pil", label="Image (optional)")
with gr.Row():
send = gr.Button("Send")
clear = gr.Button("Clear")
# keep conversation state explicitly
state = gr.State([])
def submit_fn(message, image, history):
history = history or []
if (not message or message.strip() == "") and image is None:
return history, "", history
resp = generate_medgemma_rag_response(message or "", image)
history.append((message or "", resp))
return history, "", history
send.click(submit_fn, inputs=[txt, img, state], outputs=[chatbot, txt, state])
txt.submit(submit_fn, inputs=[txt, img, state], outputs=[chatbot, txt, state])
clear.click(lambda: ([], "", []), inputs=None, outputs=[chatbot, txt, state])
if __name__ == "__main__":
    iface.launch()