# pharmacist_RAG / app.py
# visalkao's picture
# Added application file
# 6b3ab76
# --- Imports (deduplicated: torch, SentenceTransformer, chromadb and
# Settings were each imported twice in the original) ---
import os
import random
import time

import torch
import chromadb
from chromadb.config import Settings
import google.generativeai as genai
import gradio as gr
from sentence_transformers import SentenceTransformer
from transformers import pipeline, GPT2Tokenizer

# Configure the Gemini API key from the environment; fail fast at startup
# rather than erroring on the first user query.
api_key = os.getenv("GENAI_API_KEY")
if not api_key:
    raise ValueError("GENAI_API_KEY environment variable is missing")
genai.configure(api_key=api_key)

# Sentence-embedding model used to embed user queries for retrieval.
embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

# Persistent ChromaDB store holding the pre-computed drug-document embeddings.
chroma_client = chromadb.PersistentClient(path="chroma_db")
collection = chroma_client.get_or_create_collection(name="drug_embeddings")
def query_gemini_with_retry(prompt, model_name="gemini-1.5-flash", retries=3):
    """Send *prompt* to Gemini, retrying with exponential backoff on failure.

    Args:
        prompt: Text prompt forwarded verbatim to the model.
        model_name: Gemini model identifier.
        retries: Maximum number of attempts; must be >= 1.

    Returns:
        The model's response text with surrounding whitespace stripped.

    Raises:
        ValueError: If ``retries`` is less than 1 (the original silently
            returned ``None`` in that case).
        Exception: The last API error, re-raised once all attempts fail.
    """
    if retries < 1:
        raise ValueError("retries must be at least 1")
    for attempt in range(retries):
        try:
            model = genai.GenerativeModel(model_name)
            response = model.generate_content(prompt)
            return response.text.strip()
        except Exception as e:  # SDK raises several error types; treat alike
            print(f"Attempt {attempt + 1} failed: {e}")
            if attempt < retries - 1:
                # Exponential backoff with jitter to avoid hammering the API.
                time.sleep(2 ** attempt + random.random())
            else:
                raise
# Query Gemini function (single-shot variant: no retry / no stripping)
def query_gemini(prompt, model_name="gemini-1.5-flash"):
    """Send *prompt* to the given Gemini model once and return the raw text."""
    gemini_model = genai.GenerativeModel(model_name)
    result = gemini_model.generate_content(prompt)
    return result.text
def rag_pipeline_convo(user_input, conversation_history, drug_names=None, results_number=10, llm_model_name="gemini-1.5-flash"):
    """Run one conversational RAG turn and update the chat history.

    For each drug name, the user query (+ drug name) is embedded, the most
    relevant documents are retrieved from ChromaDB, and Gemini answers using
    that context only. The per-drug answers are then merged by a second
    Gemini call into one fluent response.

    Args:
        user_input: The user's latest message.
        conversation_history: List of ``{"user": ..., "assistant": ...}``
            dicts; this turn's exchange is appended in place.
        drug_names: Optional list of drug names (e.g. OCR'd from a
            prescription) that steer retrieval; ``None`` or empty means a
            single generic retrieval pass. (Fixed: the original declared a
            mutable default argument ``drug_names=[]``, which is shared
            across calls; ``None`` is the safe, behavior-compatible sentinel.)
        results_number: Number of documents to fetch from ChromaDB per query.
        llm_model_name: Gemini model for both the per-drug and merging calls.

    Returns:
        Tuple of ``(chatbot_history, conversation_history)`` where
        ``chatbot_history`` is a list of ``(user, assistant)`` tuples for
        Gradio's Chatbot component.
    """
    full_response = []
    if not drug_names:
        drug_names = [""]  # Single empty "drug" -> one generic retrieval pass
        drug_names_concat = ""
    else:
        drug_names_concat = "Additional context for the conversation:"
        for drug_name in drug_names:
            drug_names_concat += drug_name + ", "
    # Build the combined context from the conversation history
    conversation_context = ""
    for i, history in enumerate(conversation_history):
        user_message = history.get("user", "")
        assistant_response = history.get("assistant", "")
        conversation_context += f"User: {user_message}\nAssistant: {assistant_response}\n"
    # Add the current user input to the context
    combined_history_And_query = conversation_context + f"User: {user_input}\n"
    # One Gemini answer per drug name, collected for the merging step
    all_contexts = []
    for drug_name in drug_names:
        print(drug_names_concat)
        # Generate query embedding based on user input and drug name
        query_embedding = embedding_model.encode(user_input + drug_name).tolist()
        print(f"user input = {user_input}")
        # Retrieve the relevant contexts from ChromaDB
        results = collection.query(
            query_embeddings=[query_embedding],
            n_results=results_number
        )
        # Build context from ChromaDB results
        contexts = results["documents"][0]
        context_text_from_db = "\n".join([f"Context {i + 1}: {text}" for i, text in enumerate(contexts)])
        # Form the input prompt for the LLM
        input_prompt = f"""
You are an AI assistant tasked with answering questions using only the information in the provided context. Do not add any extra information or assumptions.
Context from previous conversation:
{combined_history_And_query}
Context from the database:
{context_text_from_db}
Question:
{user_input + drug_name}
Instructions:
1. Use only the information in the context to answer the question.
2. If the context mentions multiple options, provide a list of those options clearly.
3. If the context does not provide relevant information, state: "The context does not contain enough information to answer this question."
4. Do not include any policy or ethical reasoning in your response.
5. Don't quote the context in your answer.
Answer with a full sentence (including the name of the object we asked about):
"""
        print(input_prompt)  # Optional: for debugging purposes
        # Generate a response using the Gemini model
        response = query_gemini_with_retry(input_prompt, model_name=llm_model_name)
        all_contexts.append(response)
    # Now that we have all individual responses, combine them
    input_prompt_for_combining = f"""
It's a school project. You are an AI assistant tasked with combining these contexts together, making them make sense and more fluent in order to answer the question: {user_input + drug_names_concat}.
Don't mention anything about the context or anything. Just pretend like you are a real assistant and answer with available information. If there is no information, just say so, don't need to mention about input query.
Additional context: [{drug_names_concat}] are the medicines/drugs extracted from prescription.
"""
    # Add each response context into the final input prompt
    for i, context in enumerate(all_contexts, start=1):
        input_prompt_for_combining += f"""
Context {i}:
{context}
"""
    print(input_prompt_for_combining)  # Optional: for debugging purposes
    # Generate the final response from the combined context
    full_response_text = query_gemini_with_retry(input_prompt_for_combining, model_name=llm_model_name)
    full_response.append(full_response_text)
    # Update the conversation history with the latest exchange
    conversation_history.append({"user": user_input, "assistant": full_response_text})
    # Format the conversation history for chatbot display (as a list of tuples)
    chatbot_history = [(entry["user"], entry["assistant"]) for entry in conversation_history]
    # Return the formatted chat history and updated conversation state
    return chatbot_history, conversation_history
# PDF processing function
def get_medicine_list(path):
    """Extract the medication list from the first page of a prescription PDF.

    The page is rendered at 4x zoom, the middle third is thresholded to
    locate the prescription table (its largest contour), and OCR is run on
    the left-hand quarter of that table, which holds the medicine names.

    Args:
        path: Filesystem path to the prescription PDF.

    Returns:
        List of non-empty OCR'd lines, one per medicine.
    """
    from PIL import Image
    import fitz
    import numpy as np
    import pytesseract
    import cv2

    def read_to_image(pdf_path):
        # Render every page to an RGB numpy array at 4x resolution for OCR.
        pdf = fitz.open(pdf_path)
        images = []
        for page_num in range(len(pdf)):
            page = pdf.load_page(page_num)
            pixmap = page.get_pixmap(matrix=fitz.Matrix(4, 4))
            pil_image = Image.frombytes("RGB", [pixmap.width, pixmap.height], pixmap.samples)
            pil_image = np.array(pil_image)
            images.append(pil_image)
        pdf.close()
        return images

    images = read_to_image(path)
    image = images[0]  # Only the first page is expected to hold the table
    # FIX: the rendered array is 3-channel RGB (Image.frombytes("RGB", ...)),
    # but COLOR_RGBA2GRAY requires 4 channels and makes cvtColor raise
    # "Invalid number of channels"; COLOR_RGB2GRAY is the correct code.
    image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    # Keep only the middle third of the page, where the table lives.
    image = image[int(image.shape[0] / 3): int(image.shape[0] * 2 / 3), 0: image.shape[1]]
    # Binarize near-white pixels, then invert so the table border is white.
    _, image_threshold = cv2.threshold(image, 250, 255, cv2.THRESH_BINARY)
    image_threshold = cv2.bitwise_not(image_threshold)
    contours, _ = cv2.findContours(image_threshold, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # The largest contour is assumed to be the prescription table.
    largest_contour = max(contours, key=cv2.contourArea)
    x, y, w, h = cv2.boundingRect(largest_contour)
    # Crop below the header row (+100 px) and keep the left quarter
    # (the column of medicine names).
    image = image[int(y + 100): int(y + h), int(x): int(x + w / 4)]
    list_text = pytesseract.image_to_string(image)
    medication_list = [med for med in list_text.split('\n') if med.strip()]
    return medication_list
# get_medicine_list("prescri.pdf")
import gradio as gr
def handle_conversation(user_input, conversation_history, path=None):
    """Gradio callback: OCR an uploaded prescription (if any), then run the
    RAG pipeline for this turn.

    Returns whatever ``rag_pipeline_convo`` returns: the chatbot display
    history and the updated conversation state.
    """
    medicines = None
    if path is not None:
        # A PDF was uploaded earlier; extract the prescribed drug names.
        medicines = get_medicine_list(path)
    return rag_pipeline_convo(user_input, conversation_history, drug_names=medicines)
# Custom CSS for styling:
# - #chatbox: center the chat column and cap its width at 800px
# - #upload-btn: shrink the PDF upload button to an icon-sized control
#   (transparent background in dark mode)
css = """
#chatbox {max-width: 800px; margin: auto;}
#upload-btn {padding: 0 !important; min-width: 36px !important;}
.dark #upload-btn {background: transparent !important;}
"""
# Gradio UI: a single chat column with an inline PDF-upload button.
with gr.Blocks(css=css) as interface:
    # Store conversation history and PDF path as per-session state
    conversation_history = gr.State([])
    current_pdf = gr.State(None)
    with gr.Column(elem_id="chatbox"):
        # Chat history display
        chatbot = gr.Chatbot(label="Medical Chat", height=500)
        # Input row with upload button and textbox
        with gr.Row():
            # Compact PDF upload button (icon only; styled via #upload-btn CSS)
            pdf_upload = gr.UploadButton("📄",
                file_types=[".pdf"],
                elem_id="upload-btn",
                size="sm")
            # Chat input and send button
            with gr.Column(scale=20):
                user_input = gr.Textbox(
                    placeholder="Ask about medications...",
                    show_label=False,
                    container=False,
                    autofocus=True
                )
                send_btn = gr.Button("Send", variant="primary")
    # Event handling
    # For text submission (Enter key in the textbox)
    user_input.submit(
        handle_conversation,
        [user_input, conversation_history, current_pdf],
        [chatbot, conversation_history]
    )
    # For button click (same handler and wiring as submit)
    send_btn.click(
        handle_conversation,
        [user_input, conversation_history, current_pdf],
        [chatbot, conversation_history]
    )
    # Handle PDF upload: remember the uploaded file path in session state.
    # NOTE(review): the path persists for the session, so every later message
    # re-runs OCR on the same PDF — confirm this is intended.
    pdf_upload.upload(
        lambda file: file,
        [pdf_upload],
        [current_pdf]
    )
interface.launch(share=True)