# Hugging Face Space scrape residue (status header: "Sleeping") — kept as a comment
# so the file remains valid Python.
| import torch | |
| from sentence_transformers import SentenceTransformer | |
| from transformers import pipeline, GPT2Tokenizer | |
| import chromadb | |
| from chromadb.config import Settings | |
| import gradio as gr | |
| import google.generativeai as genai | |
| from sentence_transformers import SentenceTransformer | |
| import chromadb | |
| from chromadb.config import Settings | |
| import torch | |
| import time | |
| import random | |
| import os | |
# Configuration: the Gemini API key must come from the environment; failing
# fast here avoids confusing authentication errors later in the request path.
api_key = os.getenv("GENAI_API_KEY")
if not api_key:
    raise ValueError("GENAI_API_KEY environment variable is missing")
genai.configure(api_key=api_key)

# Sentence-embedding model used to vectorize user queries for retrieval.
embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

# Persistent ChromaDB collection holding the drug document embeddings.
# NOTE(review): assumes "chroma_db" was populated by a separate ingestion step — verify.
chroma_client = chromadb.PersistentClient(path="chroma_db")
collection = chroma_client.get_or_create_collection(name="drug_embeddings")
def query_gemini_with_retry(prompt, model_name="gemini-1.5-flash", retries=3):
    """Send *prompt* to Gemini, retrying with exponential backoff on failure.

    Args:
        prompt: Text prompt forwarded to the model.
        model_name: Gemini model identifier.
        retries: Maximum number of attempts before giving up.

    Returns:
        The model's response text, stripped of surrounding whitespace.

    Raises:
        The last exception encountered if every attempt fails.
    """
    last_attempt = retries - 1
    for attempt in range(retries):
        try:
            model = genai.GenerativeModel(model_name)
            return model.generate_content(prompt).text.strip()
        except Exception as e:
            print(f"Attempt {attempt + 1} failed: {e}")
            if attempt == last_attempt:
                raise
            # Exponential backoff with jitter before the next attempt.
            time.sleep(2 ** attempt + random.random())
# Query Gemini function
def query_gemini(prompt, model_name="gemini-1.5-flash"):
    """Single-shot Gemini call; returns the raw response text (no retries)."""
    response = genai.GenerativeModel(model_name).generate_content(prompt)
    return response.text
def rag_pipeline_convo(user_input, conversation_history, drug_names=None, results_number=10, llm_model_name="gemini-1.5-flash"):
    """Run one conversational RAG turn and return the updated chat state.

    For each drug name, the user query is embedded, relevant contexts are
    retrieved from ChromaDB, and Gemini produces a per-drug answer; a second
    Gemini call then merges the per-drug answers into one fluent response.

    Args:
        user_input: The user's latest message.
        conversation_history: Mutable list of {"user": ..., "assistant": ...}
            dicts; the new exchange is appended in place.
        drug_names: Optional list of drug names extracted from a prescription.
            None/empty means a single retrieval pass on the bare query.
            (Default changed from a mutable ``[]`` to ``None`` — backward
            compatible, since both are falsy and trigger the same branch.)
        results_number: Number of contexts to retrieve per query.
        llm_model_name: Gemini model identifier passed to the LLM helpers.

    Returns:
        (chatbot_history, conversation_history): list of (user, assistant)
        tuples for the Gradio chatbot, and the updated history list.
    """
    if not drug_names:
        # One pass with an empty "drug" suffix so the retrieval loop still runs.
        drug_names = [""]
        drug_names_concat = ""
    else:
        drug_names_concat = "Additional context for the conversation:"
        for drug_name in drug_names:
            drug_names_concat += drug_name + ", "

    # Flatten prior turns into a plain-text transcript for the prompt.
    conversation_context = ""
    for history in conversation_history:
        user_message = history.get("user", "")
        assistant_response = history.get("assistant", "")
        conversation_context += f"User: {user_message}\nAssistant: {assistant_response}\n"
    # Add the current user input to the context.
    combined_history_and_query = conversation_context + f"User: {user_input}\n"

    # One retrieval + generation pass per drug name.
    all_contexts = []
    for drug_name in drug_names:
        print(drug_names_concat)
        # Embed the query (plus drug name) and search ChromaDB.
        query_embedding = embedding_model.encode(user_input + drug_name).tolist()
        print(f"user input = {user_input}")
        results = collection.query(
            query_embeddings=[query_embedding],
            n_results=results_number
        )
        # Build context from ChromaDB results.
        contexts = results["documents"][0]
        context_text_from_db = "\n".join(f"Context {i + 1}: {text}" for i, text in enumerate(contexts))
        # Form the input prompt for the LLM.
        input_prompt = f"""
You are an AI assistant tasked with answering questions using only the information in the provided context. Do not add any extra information or assumptions.
Context from previous conversation:
{combined_history_and_query}
Context from the database:
{context_text_from_db}
Question:
{user_input + drug_name}
Instructions:
1. Use only the information in the context to answer the question.
2. If the context mentions multiple options, provide a list of those options clearly.
3. If the context does not provide relevant information, state: "The context does not contain enough information to answer this question."
4. Do not include any policy or ethical reasoning in your response.
5. Don't quote the context in your answer.
Answer with a full sentence (including the name of the object we asked about):
"""
        print(input_prompt)  # Optional: for debugging purposes
        response = query_gemini_with_retry(input_prompt, model_name=llm_model_name)
        all_contexts.append(response)

    # Merge the per-drug answers into one fluent final response.
    input_prompt_for_combining = f"""
It's a school project. You are an AI assistant tasked with combining these contexts together, making them make sense and more fluent in order to answer the question: {user_input + drug_names_concat}.
Don't mention anything about the context or anything. Just pretend like you are a real assistant and answer with available information. If there is no information, just say so, don't need to mention about input query.
Additional context: [{drug_names_concat}] are the medicines/drugs extracted from prescription.
"""
    for i, context in enumerate(all_contexts, start=1):
        input_prompt_for_combining += f"""
Context {i}:
{context}
"""
    print(input_prompt_for_combining)  # Optional: for debugging purposes
    full_response_text = query_gemini_with_retry(input_prompt_for_combining, model_name=llm_model_name)

    # Record the exchange, then format the history for the Gradio chatbot
    # (which expects a list of (user, assistant) tuples).
    conversation_history.append({"user": user_input, "assistant": full_response_text})
    chatbot_history = [(entry["user"], entry["assistant"]) for entry in conversation_history]
    return chatbot_history, conversation_history
# PDF processing function
def get_medicine_list(path):
    """OCR the medicine list out of a prescription PDF.

    Renders the first page at 4x zoom, crops to the middle third of the page,
    isolates the largest bright region via thresholding and contours, keeps a
    sub-crop of it, and runs Tesseract OCR on the result.

    Args:
        path: Filesystem path to the prescription PDF.

    Returns:
        List of non-empty OCR'd lines (presumably one medicine per line).
    """
    # Heavy third-party dependencies imported lazily so the module can load
    # (and the chat still works) when they are not installed.
    from PIL import Image
    import fitz
    import numpy as np
    import pytesseract
    import cv2

    def read_to_image(pdf_path):
        """Render every page of the PDF to an RGB numpy array at 4x zoom."""
        pdf = fitz.open(pdf_path)
        images = []
        for page_num in range(len(pdf)):
            page = pdf.load_page(page_num)
            pixmap = page.get_pixmap(matrix=fitz.Matrix(4, 4))
            pil_image = Image.frombytes("RGB", [pixmap.width, pixmap.height], pixmap.samples)
            images.append(np.array(pil_image))
        pdf.close()
        return images

    images = read_to_image(path)
    image = images[0]  # only the first page is processed
    # BUGFIX: the rendered image is 3-channel RGB (Image.frombytes("RGB", ...)),
    # but COLOR_RGBA2GRAY requires a 4-channel input and raises cv2.error;
    # COLOR_RGB2GRAY is the correct conversion here.
    image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    # Keep the middle third of the page height.
    image = image[int(image.shape[0] / 3): int(image.shape[0] * 2 / 3), 0: image.shape[1]]
    # Threshold near-white pixels, invert so the target region is foreground,
    # then take the bounding box of the largest contour.
    _, image_threshold = cv2.threshold(image, 250, 255, cv2.THRESH_BINARY)
    image_threshold = cv2.bitwise_not(image_threshold)
    contours, _ = cv2.findContours(image_threshold, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    largest_contour = max(contours, key=cv2.contourArea)
    x, y, w, h = cv2.boundingRect(largest_contour)
    # Skip 100px below the box top and keep its left quarter.
    # NOTE(review): these magic offsets assume one fixed prescription layout — verify.
    image = image[int(y + 100): int(y + h), int(x): int(x + w / 4)]
    list_text = pytesseract.image_to_string(image)
    medication_list = [med for med in list_text.split('\n') if med.strip()]
    return medication_list
| # get_medicine_list("prescri.pdf") | |
| import gradio as gr | |
def handle_conversation(user_input, conversation_history, path=None):
    """Gradio callback: optionally OCR a prescription PDF, then run the RAG pipeline.

    Args:
        user_input: Latest user message.
        conversation_history: Running history list (mutated by the pipeline).
        path: Optional path to an uploaded prescription PDF.

    Returns:
        Whatever rag_pipeline_convo returns: (chatbot_history, conversation_history).
    """
    # Extract medicine names from the PDF only when one was uploaded.
    medicines = get_medicine_list(path) if path is not None else None
    return rag_pipeline_convo(user_input, conversation_history, drug_names=medicines)
# Custom CSS for styling: centered chat column plus a compact upload button.
css = """
#chatbox {max-width: 800px; margin: auto;}
#upload-btn {padding: 0 !important; min-width: 36px !important;}
.dark #upload-btn {background: transparent !important;}
"""
with gr.Blocks(css=css) as interface:
    # Session state: running conversation history and the last uploaded PDF path.
    conversation_history = gr.State([])
    current_pdf = gr.State(None)
    with gr.Column(elem_id="chatbox"):
        # Chat history display
        chatbot = gr.Chatbot(label="Medical Chat", height=500)
        # Input row with upload button and textbox
        with gr.Row():
            # Compact PDF upload button
            pdf_upload = gr.UploadButton("📄",
                                         file_types=[".pdf"],
                                         elem_id="upload-btn",
                                         size="sm")
            # Chat input and send button
            with gr.Column(scale=20):
                user_input = gr.Textbox(
                    placeholder="Ask about medications...",
                    show_label=False,
                    container=False,
                    autofocus=True
                )
                send_btn = gr.Button("Send", variant="primary")
    # Event handling
    # For text submission (Enter key) — routes through handle_conversation.
    user_input.submit(
        handle_conversation,
        [user_input, conversation_history, current_pdf],
        [chatbot, conversation_history]
    )
    # For button click — same handler and wiring as Enter.
    send_btn.click(
        handle_conversation,
        [user_input, conversation_history, current_pdf],
        [chatbot, conversation_history]
    )
    # Handle PDF upload: stash the uploaded file in state for subsequent turns.
    pdf_upload.upload(
        lambda file: file,
        [pdf_upload],
        [current_pdf]
    )
# share=True exposes a public Gradio link in addition to the local server.
interface.launch(share=True)