# Scraped from a Hugging Face Space page (Space status: Sleeping) — the
# lines below are the Space's app source.
| import os | |
| import gradio as gr | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_community.document_loaders import PyPDFLoader | |
| from langchain_community.vectorstores import FAISS | |
| from langchain_community.embeddings import HuggingFaceBgeEmbeddings | |
| from langchain_groq import ChatGroq | |
| from langchain.chains import RetrievalQA | |
| from langchain_core.documents import Document | |
# LLM setup: Groq-hosted Llama 3.3 70B, low temperature for factual,
# consistent medical summaries. API key comes from the environment.
llm = ChatGroq(
    model_name="llama-3.3-70b-versatile",
    temperature=0.3,
    groq_api_key=os.getenv("GROQ_API_KEY"),
)
# Prompt template for summarization.
# NOTE: the original adjacent string literals concatenated with no separator,
# fusing the checklist into "...diagnosed withSymptoms they experienced..." and
# leaving the fallback sentence without its closing quote. Explicit "\n"
# separators restore the intended one-item-per-line prompt.
prompt_template = ChatPromptTemplate.from_messages([
    ("system",
     "You are a compassionate and skilled medical assistant. "
     "Your job is to read medical documents and explain them in simple, clear, and friendly language so patients and their families can understand. "
     "Summarize test results, diagnoses, treatments, and advice in an easy-to-follow, conversational tone. Dont use asterisks and em dashes. "
     "Avoid sounding robotic or overly technical. "
     "If the document is not related to the medical field, respond with: 'I cannot generate a summary for non-medical documents.'"),
    ("human",
     "The user has uploaded the following document:\n{context}\n"
     "If it’s a medical report, write a warm and patient-friendly summary including:\n"
     "What the patient is diagnosed with\n"
     "Symptoms they experienced\n"
     "What tests were done and what they showed\n"
     "What treatments were given\n"
     "Clear discharge advice in bullet points\n"
     "If it's not medical, just reply: 'I cannot generate a summary for non-medical documents.'\n"
     "Use a short summary in simple english terms for each."),
])

# Collapse the chat model's message object down to a plain string.
output_parser = StrOutputParser()
# Module-level state shared between the summarize handler (writer) and the
# question-answering handler (reader).
last_uploaded_pdf = None  # most recently uploaded PDF, None until first upload
vectorstore = None  # FAISS index over that PDF's pages, None until built
# PDF summarization
def summarize_pdf(pdf_file):
    """Load an uploaded PDF, index it for follow-up Q&A, and return a
    patient-friendly summary from the LLM.

    Args:
        pdf_file: Gradio upload — either a tempfile-like object with a
            ``.name`` path attribute or a plain filepath string (depends on
            the File component's ``type``); ``None`` when nothing was uploaded.

    Returns:
        A user-facing string: the summary, or an error/help message.

    Side effects:
        Updates the module-level ``last_uploaded_pdf`` and ``vectorstore``
        globals consumed by ``answer_question``.
    """
    global last_uploaded_pdf, vectorstore
    if pdf_file is None:
        return "Please upload a PDF file."
    last_uploaded_pdf = pdf_file
    # Accept both a file object (has .name) and a bare path string —
    # `pdf_file.name` alone crashes when Gradio hands over a filepath.
    pdf_path = getattr(pdf_file, "name", pdf_file)
    pages = PyPDFLoader(pdf_path).load_and_split()
    if not pages:
        # Guard against empty/unparseable PDFs instead of summarizing "".
        return "The uploaded PDF appears to be empty or unreadable."
    full_context = "\n\n".join(page.page_content for page in pages)
    # Build the vector store so answer_question can retrieve relevant chunks.
    embeddings = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-small-en-v1.5")
    vectorstore = FAISS.from_documents(pages, embeddings)
    chain = prompt_template | llm | output_parser
    return chain.invoke({"context": full_context})
# Helper to check medical relevance of query
def is_medical_query_and_context_available(query, vectorstore, threshold=20):
    """Classify a user query and check it against the indexed document.

    Args:
        query: The user's question.
        vectorstore: FAISS store built from the uploaded PDF.
        threshold: Minimum stripped chunk length (chars) for a retrieved
            chunk to count as usable context.

    Returns:
        Tuple ``(is_medical, is_context_available)`` of booleans. The second
        is only computed (and can only be True) when the first is True.
    """
    # Step 1: ask the LLM whether the query is a medical topic at all.
    classification_prompt = ChatPromptTemplate.from_messages([
        ("system", "You are a helpful assistant. Determine whether the given user question is about a medical topic."),
        ("human", "Is the following query related to the medical field? Answer with only 'Yes' or 'No'.\n\nQuery: {query}"),
    ])
    classification_chain = classification_prompt | llm | output_parser
    classification_result = classification_chain.invoke({"query": query})
    is_medical = classification_result.strip().lower().startswith("yes")

    # Step 2: for medical queries, verify the document actually contains
    # substantive chunks near the query.
    is_context_available = False
    if is_medical:
        docs = vectorstore.similarity_search(query, k=3)
        is_context_available = any(
            len(doc.page_content.strip()) > threshold for doc in docs
        )
    return is_medical, is_context_available
    # Removed: an unreachable trailing `return any(word in query.lower() for
    # word in keywords)` that referenced an undefined name `keywords` and
    # would have raised NameError had it ever executed.
# Question answering
def answer_question(user_question):
    """Answer a follow-up question against the previously indexed PDF.

    Args:
        user_question: Free-text question from the UI.

    Returns:
        A user-facing string in every case — validation message, relevance
        rejection, or the retrieval-QA answer.
    """
    global last_uploaded_pdf, vectorstore
    if not user_question.strip():
        return "Please enter a valid question."
    if last_uploaded_pdf is None or vectorstore is None:
        return "Please upload and summarize a medical PDF first."
    # Gate the question: it must be medical AND grounded in the document.
    is_medical, is_context_available = is_medical_query_and_context_available(
        user_question, vectorstore
    )
    if not is_medical:
        return "This question doesn't appear to be medical-related. Please ask a medical question."
    if not is_context_available:
        return "Sorry, I couldn't find relevant information in the uploaded document to answer your question."
    # Proceed with valid medical + relevant context.
    retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=retriever,
        chain_type="stuff",
    )
    # Chain.run is deprecated in modern LangChain; invoke returns a dict with
    # the answer under the "result" key.
    return qa_chain.invoke({"query": user_question})["result"]
# Gradio UI: two-column layout — upload/summarize on the left, Q&A on the
# right — wired to the handlers defined above.
with gr.Blocks(css=".gradio-container {max-width: 900px; margin: auto;}") as demo:
    gr.Markdown("<h1 style='text-align: center;'>🩺 Medical Report Simplifier</h1>")
    gr.Markdown("Upload a medical PDF to generate a simplified summary and ask follow-up questions.")

    with gr.Row():
        # Left column: PDF upload and summary display.
        with gr.Column():
            file_input = gr.File(label="📄 Upload Medical PDF", file_types=[".pdf"])
            summarize_btn = gr.Button("🔍 Generate Summary")
            summary_output = gr.Textbox(label="📝 Patient-Friendly Summary", lines=10)
        # Right column: follow-up question and answer display.
        with gr.Column():
            user_question = gr.Textbox(label="❓ Ask a Follow-Up Question")
            ask_btn = gr.Button("💬 Ask")
            answer_output = gr.Textbox(label="🤖 Answer", lines=5)

    # Event wiring.
    summarize_btn.click(fn=summarize_pdf, inputs=file_input, outputs=summary_output)
    ask_btn.click(fn=answer_question, inputs=user_question, outputs=answer_output)

    gr.Markdown("<hr>")
    gr.Markdown("<center>Built with ❤️ using LangChain, Groq, and Gradio</center>")

demo.launch()