File size: 14,414 Bytes
a64746a
fb735c7
 
14e0cc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b71faa5
 
 
 
fb735c7
 
 
 
 
 
 
 
1bff81d
fb735c7
c8dc87e
8c6700c
c8dc87e
 
 
 
 
 
6ed0f36
159e273
c8dc87e
 
 
 
fb735c7
 
c8dc87e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb735c7
14e0cc0
6d41a2c
 
14e0cc0
 
 
 
 
 
 
 
 
 
 
 
 
fb735c7
c8dc87e
14e0cc0
fb735c7
c8dc87e
 
fb735c7
e13955f
14e0cc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b71faa5
 
14e0cc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b71faa5
14e0cc0
 
b71faa5
14e0cc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b71faa5
14e0cc0
 
 
 
 
 
 
 
 
 
 
 
b71faa5
 
 
c8dc87e
b71faa5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c8dc87e
5ff9031
e08fc01
b71faa5
 
e13955f
dee086a
b71faa5
 
c8dc87e
b71faa5
 
461d4de
b71faa5
 
 
 
 
 
 
 
 
 
 
 
fb735c7
c8dc87e
14e0cc0
1bff81d
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
import streamlit as st
import json
import os
from datasets import load_dataset
from langchain_community.document_loaders import JSONLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_groq import ChatGroq
from uuid import uuid4
from pathlib import Path
from langchain.schema import Document
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from IPython.display import Markdown, display
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.pydantic_v1 import BaseModel, Field
from typing import List, Optional
from dataclasses import dataclass, field

# Folder where uploaded / auto-loaded JSON files (and generated outputs) are stored.
DATA_FOLDER = "data"

# Create the folder if needed. exist_ok avoids the check-then-create race of
# the previous `if not os.path.exists(...)` pattern.
os.makedirs(DATA_FOLDER, exist_ok=True)

st.title("Triomics")

# Option to upload a file or provide a local file path
input_option = st.radio("Choose input method:", ("Upload a JSON file", "Autoload"))

uploaded_file = None
local_file_path_input = None

if input_option == "Upload a JSON file":
    uploaded_file = st.file_uploader("Upload a JSON file", type=["json"])
elif input_option == "Autoload":
    local_file_path_input = "1.json"

file_path_to_process = None
file_name = None
json_data = None

if uploaded_file is not None:
    # Streamlit's uploader yields a file-like object; parse it directly.
    try:
        json_data = json.load(uploaded_file)
        file_name = uploaded_file.name
        file_path_to_process = os.path.join(DATA_FOLDER, file_name)
    except json.JSONDecodeError:
        st.error("Error: The uploaded file is not a valid JSON file.")
        st.stop()
    except Exception as e:
        st.error(f"An error occurred while processing the uploaded file: {e}")
        st.stop()
elif local_file_path_input:
    if os.path.exists(local_file_path_input):
        try:
            # Explicit encoding so parsing does not depend on the platform default.
            with open(local_file_path_input, 'r', encoding='utf-8') as f:
                json_data = json.load(f)
            file_name = os.path.basename(local_file_path_input)
            file_path_to_process = os.path.join(DATA_FOLDER, file_name)
        except json.JSONDecodeError:
            st.error("Error: The provided local file is not a valid JSON file.")
            st.stop()
        except Exception as e:
            st.error(f"An error occurred while processing the local file: {e}")
            st.stop()
    else:
        st.error(f"Error: The local file path '{local_file_path_input}' does not exist.")
        st.stop()

if json_data is not None:
    try:
        # Load API keys and Hugging Face token from environment variables
        groq_api = os.environ.get("groq_api")
        hf_token = os.environ.get("hf_token")

        if not groq_api or not hf_token:
            st.error(
                "Error: API keys (GROQ_API_KEY and HF_TOKEN) not found in environment variables."
            )
            st.info(
                "Please set the environment variables GROQ_API_KEY and HF_TOKEN."
                " You can do this in your terminal before running the script:\n"
                "`export GROQ_API_KEY='YOUR_GROQ_API_KEY'`\n"
                "`export HF_TOKEN='YOUR_HUGGINGFACE_TOKEN'`"
            )
            st.stop()

        # Save the file to the data folder
        with open(file_path_to_process, "w") as f:
            json.dump(json_data, f, indent=4)  # Save with indentation for readability

        st.success(f"File '{file_name}' successfully loaded and saved to:")
        st.code(file_path_to_process, language="plaintext")

        st.subheader("Task 1: Information Retrieval (Question-Answering)")

        if st.button("Process Data"):
            with st.spinner("Processing data..."):
                # Convert JSON data to texts and metadata
                texts = [item["docText"] for item in json_data]
                metadatas = [{"title": item["docTitle"], "date": item["docDate"]} for item in json_data]

                # Initialize the RecursiveCharacterTextSplitter
                splitter = RecursiveCharacterTextSplitter(chunk_size=700)
                docs = splitter.create_documents(texts=texts, metadatas=metadatas)

                # Initialize the HuggingFaceEmbeddings model
                embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

                # Initialize the Chroma vector store
                vector_store = Chroma(
                    collection_name="Patient_data",
                    embedding_function=embeddings,
                    persist_directory="./chroma_langchain_db",
                )
                vector_store.add_documents(documents=docs)

                llm = ChatGroq(groq_api_key=groq_api, model_name="llama-3.1-8b-instant")
                st.session_state.llm = llm # Store llm for later use

                contextualize_q_prompt = ChatPromptTemplate.from_messages(
                    [
                        ("system", """Given a chat history and the latest user question
which might reference context in the chat history, formulate a standalone question
which can be understood without the chat history. Do NOT answer the question,
just reformulate it if needed and otherwise return it as is."""),
                        MessagesPlaceholder("chat_history"),
                        ("human", "{input}"),
                    ]
                )

                chat_history_store = {}

                def get_chat_session_history(session_id: str) -> BaseChatMessageHistory:
                    """Return the message history for *session_id*, lazily creating
                    an empty one on first access."""
                    history = chat_history_store.get(session_id)
                    if history is None:
                        history = ChatMessageHistory()
                        chat_history_store[session_id] = history
                    return history

                # Answering prompt. The template text is sent to the model
                # verbatim; {context} receives the retrieved documents and
                # {input} the (reformulated) user question.
                qa_prompt_template = ChatPromptTemplate.from_template("""
**Prompt:**

**Context:**

{context}

**Question:**

{input}

**Instructions:**

1. **Carefully read and understand the provided context.**
2. **Think step-by-step to formulate a comprehensive and accurate answer.**
3. **Base your response solely on the given context.**
4. **Ensure the answer is clear, concise, and easy to understand.**
5. **Ensure the answer is in small understandable points with all content.**

**Response:**

[Your detailed and well-reasoned answer]

**Note:** This prompt emphasizes careful consideration and accurate response based on the provided context.
""")

                # Chain that stuffs all retrieved documents into {context}.
                question_answer_chain = create_stuff_documents_chain(st.session_state.llm, qa_prompt_template)

                # Retriever that first reformulates the question using chat
                # history, then runs an MMR search (10 results drawn from the
                # top 50 candidates) against the vector store.
                history_aware_retriever = create_history_aware_retriever(
                    st.session_state.llm,
                    vector_store.as_retriever(
                        search_type="mmr",
                        search_kwargs={'k': 10, 'fetch_k': 50}
                    ),
                    contextualize_q_prompt
                )

                # retrieve -> answer pipeline.
                retrieval_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

                # Wrap the pipeline so each invocation reads/appends the chat
                # history selected by the "session_id" config key.
                conversational_rag_chain = RunnableWithMessageHistory(
                    retrieval_chain,
                    get_chat_session_history,
                    input_messages_key="input",
                    history_messages_key="chat_history",
                    output_messages_key="answer",
                )

                # Persist across Streamlit reruns so the Q&A box below works.
                st.session_state.conversational_rag_chain = conversational_rag_chain
                st.session_state.chat_history_store = chat_history_store
                st.success("Data processed! You can now ask questions and generate structured output.")

        if "conversational_rag_chain" in st.session_state:
            user_question = st.text_input("Ask a question about the data:", key="user_question")
            if user_question:
                session_id = "user_session"  # You might want to make this dynamic for multiple users
                with st.spinner("Generating answer..."):
                    response = st.session_state.conversational_rag_chain.invoke(
                        {"input": user_question},
                        config={"configurable": {"session_id": session_id}},
                    )
                    st.markdown(response['answer'])

        st.subheader("Generate Structured Output")
        if st.button("Generate Structured Cancer Information"):
            with st.spinner("Generating structured output..."):
                json_data = json.loads(Path(file_path_to_process).read_text())
                context = ""
                for item in json_data:
                    context += json.dumps(item, indent=4)

                # TNM staging record. NOTE(review): the class docstring and each
                # field's metadata "description" appear to feed the structured-
                # output schema given to the LLM (via with_structured_output
                # below) — treat them as prompt text, not documentation.
                @dataclass
                class Stage:
                    """Cancer Stage information."""
                    # Tumor extent, e.g. "T2".
                    T: str = field(metadata={"description": "T Stage"})
                    # Lymph-node involvement, e.g. "N1".
                    N: str = field(metadata={"description": "N Stage"})
                    # Distant metastasis, e.g. "M0".
                    M: str = field(metadata={"description": "M Stage"})
                    # Combined label, e.g. "Stage IIB".
                    group_stage: str = field(metadata={"description": "Group Stage"})

                # One primary cancer diagnosis. NOTE(review): docstring and
                # field metadata are consumed as the extraction schema/prompt —
                # do not edit them casually.
                @dataclass
                class DiagnosisCharacteristic:
                    """Primary cancer condition details."""
                    # Cancer type, e.g. "Breast Cancer".
                    primary_cancer_condition: str = field(metadata={"description": "Primary cancer condition Example “Breast Cancer”, “Lung Cancer”, etc which given in patient data"})
                    # Earliest confirmed diagnosis date (MM-DD-YYYY).
                    diagnosis_date: str = field(metadata={"description": "Earliest date on which the cancer got confirmed Diagnosis date in MM-DD-YYYY format Example:  How to Find: Typically in sentences such as “The biopsy on 01/12/2020 confirmed invasive ductal carcinoma.” or “Pathology Report (02/17/2020): Invasive breast cancer.” c. You may see multiple references to diagnosis across notes; pick the earliest one that specifically confirms the cancer."})
                    # Microscopic subtype(s) of the tumor.
                    histology: List[str] = field(metadata={"description": """{Histological classification of the primary cancer condition, Describes the microscopic subtype of the tumor. Common examples: “Adenocarcinoma,” “Invasive ductal carcinoma,” “Squamous cell carcinoma,” etc. b. How to Find: In pathology reports or biopsy results. Terms like “Histologically consistent with adenocarcinoma” or “Invasive ductal carcinoma, Grade 2.”}"""})
                    # Nested TNM + group stage record.
                    stage: Stage = field(metadata={"description": """{Indicates Tumor size/extent. E.g., T2 means a moderate-sized tumor, T4 might mean a larger or invasive tumor. b. N: Indicates lymph Nodes involvement. N0 means no nodal involvement, N1/N2 means progressively more nodes involved. c. M: Indicates Metastasis. M0 means no distant spread; M1 means present. d. Group Stage: A single label (Stage I, Stage IIB, Stage IV, etc.) summarizing T, N, and M combined. e. How to Find: In imaging reports, pathology final reports, or physician notes, e.g. “Stage IIB (T2 N1 M0).” or “pT2 N1 M0.”}"""})

                # One cancer-directed medication. NOTE(review): docstring and
                # field metadata are consumed as the extraction schema/prompt —
                # do not edit them casually.
                @dataclass
                class CancerRelatedMedication:
                    """Cancer related medication details."""
                    # Drug name, e.g. "Paclitaxel".
                    medication_name: str = field(metadata={"description": "Medication for cancer:For example, “Doxorubicin,” “Cyclophosphamide,” “Paclitaxel,” “Trastuzumab,” “Pembrolizumab,” “Letrozole,” etc. "})
                    # Earliest start date (MM-DD-YYYY), if mentioned.
                    start_date: str = field(metadata={"description": "The earliest date this medication was started, in MM-DD-YYYY format, if available. Start date in MM-DD-YYYY format"})
                    # Stop date (MM-DD-YYYY) or blank/null if ongoing.
                    end_date: str = field(metadata={"description": "The date the medication was stopped, if mentioned. If the patient is still on the medication, you may leave it blank or mark as nullEnd date in MM-DD-YYYY format"})
                    # Free-text treatment intent (adjuvant, neoadjuvant, ...).
                    intent: str = field(metadata={"description": "A free-text field describing why the medication was given. Examples: “Adjuvant therapy post-surgery,” “Neoadjuvant therapy to shrink tumor,” “Maintenance therapy for HER2+ disease,” or “Hormonal therapy to block estrogen in ER+ cancer.”"})

                # Top-level extraction target handed to with_structured_output.
                # NOTE(review): docstring and field metadata are consumed as the
                # extraction schema/prompt — do not edit them casually.
                @dataclass
                class CancerInformation:
                    """Structured information about cancer diagnosis and medication."""
                    # One entry per distinct primary cancer found in the notes.
                    diagnosis_characteristics: List[DiagnosisCharacteristic] = field(metadata={"description": "List of primary cancers"})
                    # All cancer-directed medications mentioned for the patient.
                    cancer_related_medications: List[CancerRelatedMedication] = field(metadata={"description": "List of cancer related medication given to the patient"})

                llm = ChatGroq(groq_api_key=groq_api, model_name="llama-3.1-8b-instant")
                # Ask the model to return data shaped like CancerInformation.
                structured_llm = llm.with_structured_output(CancerInformation)
                try:
                    output = structured_llm.invoke(context)
                    st.subheader("Task 2: Medical Data Extraction- Generated Structured Output:")

                    # Bug fix: `output` is a CancerInformation dataclass
                    # instance, which json.dump cannot serialize (TypeError).
                    # Convert it to plain dicts/lists before display and save.
                    from dataclasses import asdict, is_dataclass
                    serializable = asdict(output) if is_dataclass(output) else output
                    st.json(serializable)

                    # Save the generated output to a JSON file.
                    output_filename = f"{Path(file_path_to_process).stem}_structured.json"
                    output_filepath = os.path.join(DATA_FOLDER, output_filename)
                    with open(output_filepath, "w", encoding="utf-8") as f:
                        json.dump(serializable, f, indent=4)

                    # Provide a download button for the saved file.
                    with open(output_filepath, "rb") as f:
                        st.download_button(
                            label="Download Generated JSON",
                            data=f,
                            file_name=output_filename,
                            mime="application/json",
                        )

                except Exception as e:
                    st.error(f"Error generating structured output: {e}")
    # Catch-all boundary for the whole pipeline above: surface any unexpected
    # failure in the UI instead of crashing the Streamlit script.
    except Exception as e:
        st.error(f"An unexpected error occurred: {e}")
else:
    # No JSON payload was obtained from either input method.
    st.info("Please upload a JSON file or enter a local file path.")

st.markdown("---")  # Add a horizontal rule for visual separation
st.markdown("[My linkedin](https://www.linkedin.com/in/darshankumarr/)")
st.markdown("[Resume Link](https://drive.google.com/file/d/1HAL5NmUjT5bfa-NIgo-kVQ93-ISzGijh/view?usp=drive_link)")