"""Triomics Streamlit app.

Loads a patient-record JSON file (uploaded or autoloaded), then offers:
  Task 1 - conversational question-answering over the records via a
           history-aware RAG chain (Chroma + HuggingFace embeddings + Groq LLM);
  Task 2 - extraction of structured cancer information (diagnosis + medications)
           using the LLM's structured-output mode.
"""

import streamlit as st
import json
import os
from datasets import load_dataset
from langchain_community.document_loaders import JSONLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_groq import ChatGroq
from uuid import uuid4
from pathlib import Path
from langchain.schema import Document
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from IPython.display import Markdown, display
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.pydantic_v1 import BaseModel, Field
from typing import List, Optional
from dataclasses import dataclass, field, asdict, is_dataclass

# Folder where the validated input JSON and generated outputs are persisted.
DATA_FOLDER = "data"

# Ensure the data folder exists before any file is written into it.
if not os.path.exists(DATA_FOLDER):
    os.makedirs(DATA_FOLDER)

st.title("Triomics")

# --- Input selection: a user upload or a fixed local file ("1.json"). ---
input_option = st.radio("Choose input method:", ("Upload a JSON file", "Autoload"))

uploaded_file = None
local_file_path_input = None

if input_option == "Upload a JSON file":
    uploaded_file = st.file_uploader("Upload a JSON file", type=["json"])
elif input_option == "Autoload":
    local_file_path_input = "1.json"

file_path_to_process = None  # destination path under DATA_FOLDER
file_name = None
json_data = None  # parsed payload; assumed to be a list of dicts with
                  # docText/docTitle/docDate keys -- TODO confirm schema

# --- Load and validate the JSON from whichever source was chosen. ---
if uploaded_file is not None:
    try:
        json_data = json.load(uploaded_file)
        file_name = uploaded_file.name
        file_path_to_process = os.path.join(DATA_FOLDER, file_name)
    except json.JSONDecodeError:
        st.error("Error: The uploaded file is not a valid JSON file.")
        st.stop()
    except Exception as e:
        st.error(f"An error occurred while processing the uploaded file: {e}")
        st.stop()
elif local_file_path_input:
    if os.path.exists(local_file_path_input):
        try:
            with open(local_file_path_input, 'r') as f:
                json_data = json.load(f)
            file_name = os.path.basename(local_file_path_input)
            file_path_to_process = os.path.join(DATA_FOLDER, file_name)
        except json.JSONDecodeError:
            st.error("Error: The provided local file is not a valid JSON file.")
            st.stop()
        except Exception as e:
            st.error(f"An error occurred while processing the local file: {e}")
            st.stop()
    else:
        st.error(f"Error: The local file path '{local_file_path_input}' does not exist.")
        st.stop()

if json_data is not None:
    try:
        # Load API keys and Hugging Face token from environment variables.
        # BUGFIX: the code previously read only the lowercase names
        # "groq_api"/"hf_token" while the on-screen instructions told users to
        # export GROQ_API_KEY/HF_TOKEN -- following the instructions always
        # failed. Accept both spellings (conventional name takes precedence).
        groq_api = os.environ.get("GROQ_API_KEY") or os.environ.get("groq_api")
        hf_token = os.environ.get("HF_TOKEN") or os.environ.get("hf_token")

        if not groq_api or not hf_token:
            st.error(
                "Error: API keys (GROQ_API_KEY and HF_TOKEN) not found in environment variables."
            )
            st.info(
                "Please set the environment variables GROQ_API_KEY and HF_TOKEN."
                " You can do this in your terminal before running the script:\n"
                "`export GROQ_API_KEY='YOUR_GROQ_API_KEY'`\n"
                "`export HF_TOKEN='YOUR_HUGGINGFACE_TOKEN'`"
            )
            st.stop()

        # Persist the validated JSON under DATA_FOLDER; Task 2 re-reads it later.
        with open(file_path_to_process, "w") as f:
            json.dump(json_data, f, indent=4)  # Save with indentation for readability

        st.success(f"File '{file_name}' successfully loaded and saved to:")
        st.code(file_path_to_process, language="plaintext")

        st.subheader("Task 1: Information Retrieval (Question-Answering)")

        if st.button("Process Data"):
            with st.spinner("Processing data..."):
                # Convert JSON records into raw texts plus per-document metadata.
                texts = [item["docText"] for item in json_data]
                metadatas = [{"title": item["docTitle"], "date": item["docDate"]} for item in json_data]

                # Split each document into ~700-character chunks for retrieval.
                splitter = RecursiveCharacterTextSplitter(chunk_size=700)
                docs = splitter.create_documents(texts=texts, metadatas=metadatas)

                embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

                # NOTE(review): the persisted collection is appended to on every
                # "Process Data" click, so re-processing duplicates chunks in
                # ./chroma_langchain_db -- consider resetting the collection.
                vector_store = Chroma(
                    collection_name="Patient_data",
                    embedding_function=embeddings,
                    persist_directory="./chroma_langchain_db",
                )
                vector_store.add_documents(documents=docs)

                llm = ChatGroq(groq_api_key=groq_api, model_name="llama-3.1-8b-instant")
                st.session_state.llm = llm  # Store llm for later use

                # Prompt that rewrites a follow-up question into a standalone
                # question using the chat history (it must not answer it).
                contextualize_q_prompt = ChatPromptTemplate.from_messages(
                    [
                        (
                            "system",
                            """Given a chat history and the latest user question which might reference context in the chat history, formulate a standalone question which can be understood without the chat history. Do NOT answer the question, just reformulate it if needed and otherwise return it as is.""",
                        ),
                        MessagesPlaceholder("chat_history"),
                        ("human", "{input}"),
                    ]
                )

                # In-memory, per-session chat histories keyed by session id.
                chat_history_store = {}

                def get_chat_session_history(session_id: str) -> BaseChatMessageHistory:
                    """Return (creating on first use) the message history for a session."""
                    if session_id not in chat_history_store:
                        chat_history_store[session_id] = ChatMessageHistory()
                    return chat_history_store[session_id]

                qa_prompt_template = ChatPromptTemplate.from_template("""
**Prompt:**

**Context:** {context}

**Question:** {input}

**Instructions:**
1. **Carefully read and understand the provided context.**
2. **Think step-by-step to formulate a comprehensive and accurate answer.**
3. **Base your response solely on the given context.**
4. **Ensure the answer is clear, concise, and easy to understand.**
5. **Ensure the answer is in small understandable points with all content.**

**Response:**
[Your detailed and well-reasoned answer]

**Note:** This prompt emphasizes careful consideration and accurate response based on the provided context.
""")

                question_answer_chain = create_stuff_documents_chain(st.session_state.llm, qa_prompt_template)

                # MMR retriever: over-fetch 50 candidates, return 10 diverse chunks.
                history_aware_retriever = create_history_aware_retriever(
                    st.session_state.llm,
                    vector_store.as_retriever(
                        search_type="mmr",
                        search_kwargs={'k': 10, 'fetch_k': 50}
                    ),
                    contextualize_q_prompt
                )
                retrieval_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

                conversational_rag_chain = RunnableWithMessageHistory(
                    retrieval_chain,
                    get_chat_session_history,
                    input_messages_key="input",
                    history_messages_key="chat_history",
                    output_messages_key="answer",
                )

                # Stash the chain in session_state so later Streamlit reruns
                # (triggered by widget interaction) can reuse it.
                st.session_state.conversational_rag_chain = conversational_rag_chain
                st.session_state.chat_history_store = chat_history_store
                st.success("Data processed! You can now ask questions and generate structured output.")

        if "conversational_rag_chain" in st.session_state:
            user_question = st.text_input("Ask a question about the data:", key="user_question")
            if user_question:
                session_id = "user_session"  # You might want to make this dynamic for multiple users
                with st.spinner("Generating answer..."):
                    response = st.session_state.conversational_rag_chain.invoke(
                        {"input": user_question},
                        config={"configurable": {"session_id": session_id}},
                    )
                    st.markdown(response['answer'])

            st.subheader("Generate Structured Output")
            if st.button("Generate Structured Cancer Information"):
                with st.spinner("Generating structured output..."):
                    # Re-read the persisted file and concatenate every record as
                    # pretty-printed JSON to form the extraction context.
                    json_data = json.loads(Path(file_path_to_process).read_text())
                    context = ""
                    for item in json_data:
                        context += json.dumps(item, indent=4)

                    # Schema the LLM must fill in; the `description` metadata on
                    # each field guides the structured-output extraction.
                    @dataclass
                    class Stage:
                        """Cancer Stage information."""
                        T: str = field(metadata={"description": "T Stage"})
                        N: str = field(metadata={"description": "N Stage"})
                        M: str = field(metadata={"description": "M Stage"})
                        group_stage: str = field(metadata={"description": "Group Stage"})

                    @dataclass
                    class DiagnosisCharacteristic:
                        """Primary cancer condition details."""
                        primary_cancer_condition: str = field(metadata={"description": "Primary cancer condition Example “Breast Cancer”, “Lung Cancer”, etc which given in patient data"})
                        diagnosis_date: str = field(metadata={"description": "Earliest date on which the cancer got confirmed Diagnosis date in MM-DD-YYYY format Example: How to Find: Typically in sentences such as “The biopsy on 01/12/2020 confirmed invasive ductal carcinoma.” or “Pathology Report (02/17/2020): Invasive breast cancer.” c. You may see multiple references to diagnosis across notes; pick the earliest one that specifically confirms the cancer."})
                        histology: List[str] = field(metadata={"description": """{Histological classification of the primary cancer condition, Describes the microscopic subtype of the tumor. Common examples: “Adenocarcinoma,” “Invasive ductal carcinoma,” “Squamous cell carcinoma,” etc. b. How to Find: In pathology reports or biopsy results. Terms like “Histologically consistent with adenocarcinoma” or “Invasive ductal carcinoma, Grade 2.”}"""})
                        stage: Stage = field(metadata={"description": """{Indicates Tumor size/extent. E.g., T2 means a moderate-sized tumor, T4 might mean a larger or invasive tumor. b. N: Indicates lymph Nodes involvement. N0 means no nodal involvement, N1/N2 means progressively more nodes involved. c. M: Indicates Metastasis. M0 means no distant spread; M1 means present. d. Group Stage: A single label (Stage I, Stage IIB, Stage IV, etc.) summarizing T, N, and M combined. e. How to Find: In imaging reports, pathology final reports, or physician notes, e.g. “Stage IIB (T2 N1 M0).” or “pT2 N1 M0.”}"""})

                    @dataclass
                    class CancerRelatedMedication:
                        """Cancer related medication details."""
                        medication_name: str = field(metadata={"description": "Medication for cancer:For example, “Doxorubicin,” “Cyclophosphamide,” “Paclitaxel,” “Trastuzumab,” “Pembrolizumab,” “Letrozole,” etc. "})
                        start_date: str = field(metadata={"description": "The earliest date this medication was started, in MM-DD-YYYY format, if available. Start date in MM-DD-YYYY format"})
                        end_date: str = field(metadata={"description": "The date the medication was stopped, if mentioned. If the patient is still on the medication, you may leave it blank or mark as nullEnd date in MM-DD-YYYY format"})
                        intent: str = field(metadata={"description": "A free-text field describing why the medication was given. Examples: “Adjuvant therapy post-surgery,” “Neoadjuvant therapy to shrink tumor,” “Maintenance therapy for HER2+ disease,” or “Hormonal therapy to block estrogen in ER+ cancer.”"})

                    @dataclass
                    class CancerInformation:
                        """Structured information about cancer diagnosis and medication."""
                        diagnosis_characteristics: List[DiagnosisCharacteristic] = field(metadata={"description": "List of primary cancers"})
                        cancer_related_medications: List[CancerRelatedMedication] = field(metadata={"description": "List of cancer related medication given to the patient"})

                    llm = ChatGroq(groq_api_key=groq_api, model_name="llama-3.1-8b-instant")
                    structured_llm = llm.with_structured_output(CancerInformation)

                    try:
                        output = structured_llm.invoke(context)

                        # BUGFIX: `output` may be a dataclass instance, which
                        # json.dump cannot serialize (TypeError); convert to a
                        # plain dict first. Dict/None outputs pass through.
                        if is_dataclass(output) and not isinstance(output, type):
                            output = asdict(output)

                        st.subheader("Task 2: Medical Data Extraction- Generated Structured Output:")
                        st.json(output)

                        # Save the generated output to a JSON file
                        output_filename = f"{Path(file_path_to_process).stem}_structured.json"
                        output_filepath = os.path.join(DATA_FOLDER, output_filename)
                        with open(output_filepath, "w") as f:
                            json.dump(output, f, indent=4)

                        # Provide a download button
                        with open(output_filepath, "rb") as f:
                            st.download_button(
                                label="Download Generated JSON",
                                data=f,
                                file_name=output_filename,
                                mime="application/json",
                            )
                    except Exception as e:
                        st.error(f"Error generating structured output: {e}")

    except Exception as e:
        st.error(f"An unexpected error occurred: {e}")
else:
    st.info("Please upload a JSON file or enter a local file path.")

st.markdown("---")  # Add a horizontal rule for visual separation
st.markdown("[My linkedin](https://www.linkedin.com/in/darshankumarr/)")
st.markdown("[Resume Link](https://drive.google.com/file/d/1HAL5NmUjT5bfa-NIgo-kVQ93-ISzGijh/view?usp=drive_link)")