visalkao committed on
Commit
6b3ab76
·
1 Parent(s): b7855e3

Added application file

Browse files
app.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Gradio RAG chatbot: answers medication questions with Gemini over a
pre-built ChromaDB store of drug embeddings."""
import os
import random
import time

import chromadb
import google.generativeai as genai
import gradio as gr
import torch  # NOTE(review): unused below — kept in case another module relies on it
from chromadb.config import Settings  # NOTE(review): unused — kept for compatibility
from sentence_transformers import SentenceTransformer
from transformers import pipeline, GPT2Tokenizer  # NOTE(review): unused — kept for compatibility

# Fail fast if the Gemini API key is not configured.
api_key = os.getenv("GENAI_API_KEY")
if not api_key:
    raise ValueError("GENAI_API_KEY environment variable is missing")

genai.configure(api_key=api_key)

# Sentence embedding model used to encode user queries for retrieval.
embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
# Persistent local vector store holding the pre-computed drug embeddings.
chroma_client = chromadb.PersistentClient(path="chroma_db")
collection = chroma_client.get_or_create_collection(name="drug_embeddings")
24
+
25
+
26
def query_gemini_with_retry(prompt, model_name="gemini-1.5-flash", retries=3):
    """Send *prompt* to Gemini, retrying with exponential backoff on failure.

    Returns the stripped response text. After ``retries`` failed attempts the
    last exception is re-raised to the caller.
    """
    last_attempt = retries - 1
    for attempt in range(retries):
        try:
            return genai.GenerativeModel(model_name).generate_content(prompt).text.strip()
        except Exception as err:
            print(f"Attempt {attempt + 1} failed: {err}")
            if attempt == last_attempt:
                raise
            # Exponential backoff with jitter before trying again.
            time.sleep(2 ** attempt + random.random())
38
# Query Gemini function
def query_gemini(prompt, model_name="gemini-1.5-flash"):
    """Single-shot Gemini call (no retry); returns the raw response text."""
    gemini = genai.GenerativeModel(model_name)
    return gemini.generate_content(prompt).text
43
def rag_pipeline_convo(user_input, conversation_history, drug_names=None, results_number=10, llm_model_name="gemini-1.5-flash"):
    """Answer *user_input* with retrieval-augmented generation over the drug store.

    For each drug name, retrieves ``results_number`` relevant chunks from
    ChromaDB, asks Gemini to answer from that context, then asks Gemini once
    more to merge the per-drug answers into a single fluent reply.

    Args:
        user_input: The user's current question.
        conversation_history: Mutable list of {"user": ..., "assistant": ...}
            dicts; the new exchange is appended in place.
        drug_names: Optional list of medication names extracted from a
            prescription (``None``/empty means no extra drug context).
        results_number: Number of nearest chunks to fetch per drug.
        llm_model_name: Gemini model used for both stages.

    Returns:
        A ``(chatbot_history, conversation_history)`` pair, where
        ``chatbot_history`` is a list of (user, assistant) tuples for display.
    """
    full_response = []

    # BUGFIX: default was a mutable ``[]`` argument; ``None`` avoids the shared
    # default while the falsy check below keeps the original behavior.
    if not drug_names:
        drug_names = [""]  # run the retrieval loop exactly once, with no drug
        drug_names_concat = ""
    else:
        drug_names_concat = "Additional context for the conversation:" + ", ".join(drug_names) + ", "

    # Flatten previous exchanges into a plain-text transcript.
    conversation_context = ""
    for history in conversation_history:
        user_message = history.get("user", "")
        assistant_response = history.get("assistant", "")
        conversation_context += f"User: {user_message}\nAssistant: {assistant_response}\n"

    # Append the current question to the transcript.
    combined_history_and_query = conversation_context + f"User: {user_input}\n"

    # One intermediate Gemini answer per drug name.
    all_contexts = []

    for drug_name in drug_names:
        print(drug_names_concat)

        # Embed the question (plus drug name) and retrieve nearest chunks.
        query_embedding = embedding_model.encode(user_input + drug_name).tolist()
        print(f"user input = {user_input}")
        results = collection.query(
            query_embeddings=[query_embedding],
            n_results=results_number
        )

        # Number the retrieved chunks for the prompt.
        contexts = results["documents"][0]
        context_text_from_db = "\n".join(f"Context {i + 1}: {text}" for i, text in enumerate(contexts))

        # Stage 1 prompt: answer strictly from the retrieved context.
        input_prompt = f"""
You are an AI assistant tasked with answering questions using only the information in the provided context. Do not add any extra information or assumptions.
Context from previous conversation:
{combined_history_and_query}

Context from the database:
{context_text_from_db}

Question:
{user_input + drug_name}

Instructions:
1. Use only the information in the context to answer the question.
2. If the context mentions multiple options, provide a list of those options clearly.
3. If the context does not provide relevant information, state: "The context does not contain enough information to answer this question."
4. Do not include any policy or ethical reasoning in your response.
5. Don't quote the context in your answer.

Answer with a full sentence (including the name of the object we asked about):
"""
        print(input_prompt)  # Optional: for debugging purposes
        response = query_gemini_with_retry(input_prompt, model_name=llm_model_name)
        all_contexts.append(response)

    # Stage 2 prompt: merge all per-drug answers into one fluent reply.
    input_prompt_for_combining = f"""
It's a school project. You are an AI assistant tasked with combining these contexts together, making them make sense and more fluent in order to answer the question: {user_input + drug_names_concat}.
Don't mention anything about the context or anything. Just pretend like you are a real assistant and answer with available information. If there is no information, just say so, don't need to mention about input query.

Additional context: [{drug_names_concat}] are the medicines/drugs extracted from prescription.
"""

    for i, context in enumerate(all_contexts, start=1):
        input_prompt_for_combining += f"""
Context {i}:
{context}
"""

    print(input_prompt_for_combining)  # Optional: for debugging purposes
    full_response_text = query_gemini_with_retry(input_prompt_for_combining, model_name=llm_model_name)
    full_response.append(full_response_text)

    # Record the exchange so follow-up questions see it.
    conversation_history.append({"user": user_input, "assistant": full_response_text})

    # Gradio's Chatbot expects a list of (user, assistant) tuples.
    chatbot_history = [(entry["user"], entry["assistant"]) for entry in conversation_history]

    return chatbot_history, conversation_history
138
+
139
+
140
+
141
# PDF processing function
def get_medicine_list(path):
    """Extract the medication list from the first page of a prescription PDF.

    Renders page 1 at 4x zoom, crops to the middle third of the page, locates
    the largest bright region (presumably the prescription table — confirm
    against real documents), OCRs its left quarter and returns the non-empty
    lines as medication names.
    """
    from PIL import Image
    import fitz
    import numpy as np
    import pytesseract
    import cv2

    def read_to_image(pdf_path):
        # Render every page to an RGB numpy array at 4x zoom.
        pdf = fitz.open(pdf_path)
        images = []
        for page_num in range(len(pdf)):
            page = pdf.load_page(page_num)
            pixmap = page.get_pixmap(matrix=fitz.Matrix(4, 4))
            pil_image = Image.frombytes("RGB", [pixmap.width, pixmap.height], pixmap.samples)
            images.append(np.array(pil_image))
        pdf.close()
        return images

    images = read_to_image(path)
    image = images[0]  # only the first page is processed
    # BUGFIX: pages are rendered as 3-channel RGB above, so COLOR_RGB2GRAY is
    # the correct conversion; COLOR_RGBA2GRAY expects 4 channels and raises.
    image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    # Keep only the middle third of the page height.
    image = image[int(image.shape[0] / 3): int(image.shape[0] * 2 / 3), 0: image.shape[1]]
    # Threshold near-white pixels, then invert so that region is foreground.
    _, image_threshold = cv2.threshold(image, 250, 255, cv2.THRESH_BINARY)
    image_threshold = cv2.bitwise_not(image_threshold)
    contours, _ = cv2.findContours(image_threshold, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    largest_contour = max(contours, key=cv2.contourArea)
    x, y, w, h = cv2.boundingRect(largest_contour)
    # Skip 100 px below the region's top (header) and keep the left quarter,
    # which holds the medicine-name column.
    image = image[int(y + 100): int(y + h), int(x): int(x + w / 4)]
    list_text = pytesseract.image_to_string(image)
    # One medication per non-blank OCR line.
    medication_list = [med for med in list_text.split('\n') if med.strip()]
    return medication_list
174
+
175
+ # get_medicine_list("prescri.pdf")
176
+
177
+
178
+ import gradio as gr
179
+
180
def handle_conversation(user_input, conversation_history, path=None):
    """Bridge between the Gradio UI and the RAG pipeline.

    When a prescription PDF *path* has been uploaded, its medication list is
    extracted and passed to the pipeline as extra drug context; otherwise the
    pipeline runs with no drug names.
    """
    extracted_data = get_medicine_list(path) if path is not None else None
    return rag_pipeline_convo(user_input, conversation_history, drug_names=extracted_data)
187
+
188
# Custom CSS for styling
css = """
#chatbox {max-width: 800px; margin: auto;}
#upload-btn {padding: 0 !important; min-width: 36px !important;}
.dark #upload-btn {background: transparent !important;}
"""

with gr.Blocks(css=css) as interface:
    # Per-session state: running conversation and the last uploaded PDF path.
    conversation_history = gr.State([])
    current_pdf = gr.State(None)

    with gr.Column(elem_id="chatbox"):
        # Chat transcript display.
        chatbot = gr.Chatbot(label="Medical Chat", height=500)

        # Input row: compact PDF upload button next to the text input.
        with gr.Row():
            pdf_upload = gr.UploadButton(
                "📄",
                file_types=[".pdf"],
                elem_id="upload-btn",
                size="sm",
            )

            with gr.Column(scale=20):
                user_input = gr.Textbox(
                    placeholder="Ask about medications...",
                    show_label=False,
                    container=False,
                    autofocus=True,
                )

            send_btn = gr.Button("Send", variant="primary")

    # Pressing Enter in the textbox and clicking "Send" both run the pipeline.
    user_input.submit(
        handle_conversation,
        [user_input, conversation_history, current_pdf],
        [chatbot, conversation_history],
    )
    send_btn.click(
        handle_conversation,
        [user_input, conversation_history, current_pdf],
        [chatbot, conversation_history],
    )

    # Remember the uploaded PDF so later questions use its medication list.
    pdf_upload.upload(
        lambda file: file,
        [pdf_upload],
        [current_pdf],
    )

interface.launch(share=True)
chroma_db/chroma.sqlite3 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e331ff9dcd9a17dcf82fc191451c962c72c2ed83b5bb3b2dff8aa6331d0a3b98
3
+ size 121036800
chroma_db/e94c7dda-7540-4220-8d4e-b350abfb7daa/data_level0.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6caf4d0af3acc1a7976a36b85692d1ce84d6da716c8ea5f19fa81fea6b65388d
3
+ size 53632000
chroma_db/e94c7dda-7540-4220-8d4e-b350abfb7daa/header.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9bb2764f9ad50a5bac4c0e4c04de0c8b743ab15122f4f95fad6dd553440e41be
3
+ size 100
chroma_db/e94c7dda-7540-4220-8d4e-b350abfb7daa/index_metadata.pickle ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3407c8b7a220baefe0df7232212f5f8f4ab08661a94636e3842ed938eef660b0
3
+ size 851136
chroma_db/e94c7dda-7540-4220-8d4e-b350abfb7daa/length.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a202fa0f6c299fa081692f8474136c45005d13650b1cf6df0db035c43342f72c
3
+ size 128000
chroma_db/e94c7dda-7540-4220-8d4e-b350abfb7daa/link_lists.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:422451f91cbccc8164da66c171d465bc9423fdb268602011e4c2c4b405cb195e
3
+ size 273248
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ sentence-transformers
3
+ transformers
4
+ chromadb
5
+ gradio
6
+ google-generativeai