Spaces:
Sleeping
Sleeping
| # chainlit run app.py -w | |
| # Standard library imports | |
| import asyncio | |
| import io | |
| import json | |
| import os | |
| import re | |
| import requests | |
| import zipfile | |
| # Data handling | |
| import pandas as pd | |
| # Environment variables | |
| from dotenv import load_dotenv | |
| # Typing for function signatures | |
| from typing import Any, List, Optional | |
| # Bioinformatics | |
| from Bio import Entrez, Medline | |
| # ChainLit specific imports | |
| import chainlit as cl | |
| from chainlit.types import AskFileResponse | |
| # Langchain imports for AI and chat models | |
| from langchain.chains import ConversationalRetrievalChain, LLMChain | |
| from langchain_community.chat_models import ChatOpenAI | |
| from langchain.docstore.document import Document | |
| from langchain.evaluation import StringEvaluator | |
| from langchain.memory import ChatMessageHistory, ConversationBufferMemory | |
| from langchain.prompts import PromptTemplate | |
| from langchain.prompts.chat import ( | |
| ChatPromptTemplate, | |
| SystemMessagePromptTemplate, | |
| HumanMessagePromptTemplate, | |
| ) | |
| from langchain.smith import RunEvalConfig, run_on_dataset | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain.callbacks.tracers.evaluation import EvaluatorCallbackHandler | |
| from langchain_openai import OpenAI, OpenAIEmbeddings | |
| # Vector storage and document loading | |
| from langchain_community.document_loaders import DataFrameLoader | |
| from langchain_community.vectorstores import Qdrant | |
| from qdrant_client import QdrantClient | |
| from qdrant_client import AsyncQdrantClient | |
| # Custom evaluations | |
| from custom_eval import PharmAssistEvaluator, HarmfulnessEvaluator, AIDetectionEvaluator | |
| # LangSmith for client interaction | |
| from langsmith import Client | |
| langsmith_client = Client() | |
| # Load environment variables from a .env file | |
| load_dotenv() | |
| # Define system template for the chatbot | |
| system_template = """ | |
| You are , an AI assistant for pharmacists and pharmacy students. Use the following pieces of context to answer the user's question. | |
| If you don't know the answer, simply state that you don't have enough information to provide an answer. Do not attempt to make up an answer. | |
| ALWAYS include a "SOURCES" section at the end of your response, referencing the specific documents from which you derived your answer. | |
| If the user greets you with a greeting like "Hi", "Hello", or "How are you", respond in a friendly manner. | |
| Example response format: | |
| <answer> | |
| SOURCES: <document_references> | |
| Begin! | |
| ---------------- | |
| {summaries} | |
| """ | |
| # Define messages for the chatbot prompt | |
| messages = [ | |
| SystemMessagePromptTemplate.from_template(system_template), | |
| HumanMessagePromptTemplate.from_template("{question}"), | |
| ] | |
| prompt = ChatPromptTemplate.from_messages(messages) | |
| chain_type_kwargs = {"prompt": prompt} | |
| qdrant_vectorstore = None | |
| # Function to search for related papers on PubMed | |
| async def search_related_papers(query, max_results=3): | |
| """ | |
| Search PubMed for papers related to the provided query and return a list of formatted strings with paper details and URLs. | |
| """ | |
| try: | |
| # Set up Entrez email (replace with your email) | |
| Entrez.email = os.environ.get("ENTREZ_EMAIL") | |
| # Search PubMed for related papers | |
| handle = Entrez.esearch(db="pubmed", term=query, retmax=max_results) | |
| record = Entrez.read(handle) | |
| handle.close() | |
| # Retrieve the details of the related papers | |
| id_list = record["IdList"] | |
| if not id_list: | |
| return ["No directly related papers found. Try broadening your search query."] | |
| handle = Entrez.efetch(db="pubmed", id=id_list, rettype="medline", retmode="text") | |
| records = Medline.parse(handle) | |
| related_papers = [] | |
| for record in records: | |
| title = record.get("TI", "") | |
| authors = ", ".join(record.get("AU", [])) | |
| citation = f"{authors}. {title}. {record.get('SO', '')}" | |
| url = f"https://pubmed.ncbi.nlm.nih.gov/{record['PMID']}/" | |
| related_papers.append(f"[{citation}]({url})") | |
| if not related_papers: | |
| related_papers = ["No directly related papers found. Try broadening your search query."] | |
| return related_papers | |
| except Exception as e: | |
| print(f"Error occurred while searching for related papers: {e}") | |
| return ["An error occurred while searching for related papers. Please try again later."] | |
| # Function to generate related questions based on retrieved results | |
| async def generate_related_questions(retrieved_results, num_questions=2, max_tokens=50): | |
| """ | |
| Generate related questions based on the provided retrieved results from a document store. | |
| """ | |
| llm = OpenAI(temperature=0.7) | |
| prompt = PromptTemplate( | |
| input_variables=["context"], | |
| template="Given the following context, generate {num_questions} related questions:\n\nContext: {context}\n\nQuestions:", | |
| ) | |
| chain = LLMChain(llm=llm, prompt=prompt) | |
| context = " ".join([doc.page_content for doc in retrieved_results]) | |
| generated_questions = chain.run(context=context, num_questions=num_questions, max_tokens=max_tokens) | |
| # Remove numbering from the generated questions | |
| related_questions = [question.split(". ", 1)[-1] for question in generated_questions.split("\n") if question.strip()] | |
| return related_questions | |
| # Function to generate answer based on user's query | |
| async def generate_answer(query): | |
| """ | |
| Generate an answer to the user's query using a conversational retrieval chain and handle callbacks for related questions and papers. | |
| """ | |
| # Initialize a message history to track the conversation | |
| message_history = ChatMessageHistory() | |
| # Set up memory to hold the conversation context and return answers | |
| memory = ConversationBufferMemory( | |
| memory_key="chat_history", | |
| output_key="answer", | |
| chat_memory=message_history, | |
| return_messages=True, | |
| ) | |
| # Create a retrieval chain combining the LLM and the retriever | |
| chain = ConversationalRetrievalChain.from_llm( | |
| ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0, streaming=True), | |
| chain_type="stuff", | |
| retriever=qdrant_vectorstore.as_retriever(), | |
| memory=memory, | |
| return_source_documents=True, | |
| ) | |
| try: | |
| # Define callback handler for asynchronous operations | |
| cb = cl.AsyncLangchainCallbackHandler() | |
| feedback_callback = EvaluatorCallbackHandler(evaluators=[PharmAssistEvaluator(),HarmfulnessEvaluator(),AIDetectionEvaluator()]) | |
| # Process the incoming message using the conversational chain | |
| res = await chain.acall(query, callbacks=[cb,feedback_callback]) | |
| answer = res["answer"] | |
| source_documents = res["source_documents"] | |
| if answer.lower().startswith("i don't know") or answer.lower().startswith("i don't have enough information"): | |
| return answer, [], [], [],[] | |
| text_elements = [] | |
| if source_documents: | |
| for source_idx, source_doc in enumerate(source_documents): | |
| source_name = f"source_{source_idx}" | |
| text_elements.append( | |
| cl.Text(content=source_doc.page_content, name=source_name) | |
| ) | |
| source_names = [text_el.name for text_el in text_elements] | |
| if source_names: | |
| answer += f"\n\n**SOURCES:** {', '.join(source_names)}" | |
| else: | |
| answer += "\n\n**SOURCES:** No sources found" | |
| related_questions = await generate_related_questions(source_documents) | |
| related_question_actions = [ | |
| cl.Action(name="related_question", value=question.strip(), label=question.strip()) | |
| for question in related_questions if question.strip() | |
| ] | |
| # Search for related papers on PubMed | |
| related_papers = await search_related_papers(query) | |
| return answer, text_elements, related_question_actions, related_papers, query | |
| except Exception as e: | |
| print(f"Error occurred: {e}") | |
| return "An error occurred while processing your request. Please try again later.", [], [], [],[], query | |
| # Action callback for related question selection | |
| async def on_related_question_selected(action: cl.Action): | |
| """ | |
| Handle the selection of a related question, generate and send answers and further interactions. | |
| """ | |
| question = action.value | |
| await cl.Message(content=question, author="User").send() | |
| answer, text_elements, related_question_actions, related_papers, query = await generate_answer(question) | |
| # Send the processed answer back to the user | |
| await cl.Message(content=answer, elements=text_elements, author="PharmAssistAI").send() | |
| # Send related questions as a separate message | |
| if related_question_actions: | |
| await cl.Message(content="**Related Questions:**", actions=related_question_actions, author="PharmAssistAI").send() | |
| # Send related papers as a separate message | |
| if related_papers: | |
| related_papers_content = "**Related Papers from PubMed:**\n" + "\n".join(f"- {paper}" for paper in related_papers) | |
| await cl.Message(content=related_papers_content, author="PharmAssistAI").send() | |
| # Action callback for question selection | |
| async def on_question_selected(action: cl.Action): | |
| """ | |
| Respond to user-selected questions from suggested list, generate and send the answers. | |
| """ | |
| question = action.value | |
| await cl.Message(content=question, author="User").send() | |
| answer, text_elements, related_question_actions, related_papers,query = await generate_answer(question) | |
| await cl.Message(content=answer, elements=text_elements, author="").send() | |
| # Send related questions as a separate message | |
| if related_question_actions: | |
| await cl.Message(content="**Related Questions:**", actions=related_question_actions, author="").send() | |
| # Send related papers as a separate message | |
| if related_papers: | |
| related_papers_content = "**Related Papers from PubMed:**\n" + "\n".join(f"- {paper}" for paper in related_papers) | |
| await cl.Message(content=related_papers_content, author="").send() | |
| # Callback for chat start event | |
| async def on_chat_start(): | |
| """ | |
| Initialize the chatbot environment, load necessary data, and present initial user interactions. | |
| """ | |
| global qdrant_vectorstore | |
| # Display a preloader message | |
| await cl.Message(content="**Loading PharmAssistAI bot**....").send() | |
| await asyncio.sleep(2) # Add a 2-second delay to simulate loading | |
| # Adding logo for chatbot | |
| await cl.Avatar( | |
| name="", | |
| url="https://i.imgur.com/ZkIVmxp.jpeg", | |
| ).send() | |
| # Adding logo for user who is asking questions | |
| await cl.Avatar( | |
| name="User", | |
| url="https://i.imgur.com/XhmbgvT.jpeg", | |
| ).send() | |
| if qdrant_vectorstore is None: | |
| embedding_model = OpenAIEmbeddings(model="text-embedding-3-small") | |
| QDRANT_API_KEY=os.environ.get("QDRANT_API_KEY") | |
| QDRANT_CLUSTER_URL =os.environ.get("QDRANT_CLUSTER_URL") | |
| qdrant_client = AsyncQdrantClient(url=QDRANT_CLUSTER_URL, api_key=QDRANT_API_KEY,timeout=60) | |
| response = await qdrant_client.get_collections() | |
| # Extracting the collection names from the response | |
| collection_names = [collection.name for collection in response.collections] | |
| if "fda_drugs" not in collection_names: | |
| print("Collection 'fda_drugs' is not present.") | |
| # Download the data file | |
| url = "https://download.open.fda.gov/drug/label/drug-label-0001-of-0012.json.zip" | |
| response = requests.get(url) | |
| # Extract the JSON file from the zip | |
| zip_file = zipfile.ZipFile(io.BytesIO(response.content)) | |
| json_file = zip_file.open(zip_file.namelist()[0]) | |
| # Load the JSON data | |
| data = json.load(json_file) | |
| df = pd.json_normalize(data['results']) | |
| selected_drugs = df | |
| # Define metadata fields to include | |
| metadata_fields = ['openfda.brand_name', 'openfda.generic_name', 'openfda.manufacturer_name', | |
| 'openfda.product_type', 'openfda.route', 'openfda.substance_name', | |
| 'openfda.rxcui', 'openfda.spl_id', 'openfda.package_ndc'] | |
| # Define text fields to index | |
| text_fields = ['description', 'indications_and_usage', 'contraindications', | |
| 'warnings', 'adverse_reactions', 'dosage_and_administration'] | |
| # Replace NaN values with empty strings | |
| selected_drugs[text_fields] = selected_drugs[text_fields].fillna('') | |
| selected_drugs['content'] = selected_drugs[text_fields].apply(lambda x: ' '.join(x.astype(str)), axis=1) | |
| loader = DataFrameLoader(selected_drugs, page_content_column='content') | |
| drug_docs = loader.load() | |
| for doc, row in zip(drug_docs, selected_drugs.to_dict(orient='records')): | |
| metadata = {} | |
| for field in metadata_fields: | |
| value = row.get(field) | |
| if isinstance(value, list): | |
| value = ', '.join(str(v) for v in value if pd.notna(v)) | |
| elif pd.isna(value): | |
| value = 'Not Available' | |
| metadata[field] = value | |
| doc.metadata = metadata # Update the metadata to only include specified fields | |
| text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100) | |
| split_drug_docs = text_splitter.split_documents(drug_docs) | |
| # Asynchronously create a Qdrant vector store with the document chunks | |
| qdrant_vectorstore = await cl.make_async(Qdrant.from_documents)( | |
| split_drug_docs, | |
| embedding_model, | |
| url=QDRANT_CLUSTER_URL, | |
| api_key=QDRANT_API_KEY, | |
| collection_name="fda_drugs" # Name of the collection in Qdrant | |
| ) | |
| else: | |
| print("Collection 'fda_drugs' is present.") | |
| # Load the existing collection | |
| qdrant_vectorstore = await cl.make_async(Qdrant.construct_instance)( | |
| texts=[""], # no texts to add | |
| embedding = embedding_model, | |
| url=QDRANT_CLUSTER_URL, | |
| api_key=QDRANT_API_KEY, | |
| collection_name="fda_drugs" # Name of the collection in Qdrant | |
| ) | |
| potential_questions = [ | |
| "What should I be careful of when taking Metformin?", | |
| "What are the contraindications of Aspirin?", | |
| "Are there low-cost alternatives to branded Aspirin available over-the-counter?", | |
| "What precautions should I take if I'm pregnant or nursing while on Lipitor?", | |
| "Should Lipitor be taken at a specific time of day, and does it need to be taken with food?", | |
| "What is the recommended dose of Aspirin?", | |
| "Can older people take beta blockers?", | |
| "How do beta blockers work?", | |
| "Can beta blockers be used for anxiety?", | |
| "I am taking Aspirin, is it ok to take Glipizide?", | |
| "Explain in simple terms how Metformin works?" | |
| ] | |
| await cl.Message( | |
| content="**Welcome to PharmAssistAI ! Here are some potential questions you can ask:**", | |
| actions=[cl.Action(name="ask_question", value=question, label=question) for question in potential_questions] | |
| ).send() | |
| cl.user_session.set("potential_questions_shown", True) | |
| # Main function to handle user messages | |
| async def main(message): | |
| """ | |
| Process user messages, generate and send responses, and handle further interactions based on the user's queries. | |
| """ | |
| query = message.content | |
| try: | |
| answer, text_elements, related_question_actions, related_papers, original_query = await generate_answer(query) | |
| # Create a new message with the answer and source documents | |
| answer_message = cl.Message(content=answer, elements=text_elements, author="PharmAssistAI") | |
| # Send the answer message | |
| await answer_message.send() | |
| if not answer.lower().startswith("i don't know") and not answer.lower().startswith("i don't have enough information"): | |
| # Send related questions as a separate message | |
| if related_question_actions: | |
| await cl.Message(content="**Related Questions:**", actions=related_question_actions, author="PharmAssistAI").send() | |
| # Send related papers as a separate message | |
| if related_papers: | |
| related_papers_content = "**Related Papers from PubMed:**\n" + "\n".join(f"- {paper}" for paper in related_papers) | |
| await cl.Message(content=related_papers_content, author="PharmAssistAI").send() | |
| except Exception as e: | |
| print(f"Error occurred: {e}") | |
| answer = "An error occurred while processing your request. Please try again later." | |
| await cl.Message(content=answer, author="PharmAssistAI").send() |