# Triomics-app / app.py
# Author: Darshan03 — "Update app.py" (commit 1bff81d, verified)
# Standard library
import json
import os
from dataclasses import asdict, dataclass, field, is_dataclass
from pathlib import Path
from typing import List, Optional
from uuid import uuid4

# Third-party
import streamlit as st
from datasets import load_dataset
from IPython.display import Markdown, display
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.document_loaders import JSONLoader
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
# Folder where uploaded/processed JSON files are stored.
DATA_FOLDER = "data"

# Create the folder if missing. exist_ok=True avoids the racy
# exists()-then-makedirs() pattern of the original code.
os.makedirs(DATA_FOLDER, exist_ok=True)
st.title("Triomics")

# How should the source JSON arrive? Either a manual upload or the bundled
# sample file ("1.json") that ships alongside the app.
input_option = st.radio("Choose input method:", ("Upload a JSON file", "Autoload"))

uploaded_file = (
    st.file_uploader("Upload a JSON file", type=["json"])
    if input_option == "Upload a JSON file"
    else None
)
local_file_path_input = "1.json" if input_option == "Autoload" else None

# Filled in by the loading step below.
file_path_to_process = None
file_name = None
json_data = None
# Load the JSON payload from whichever source the user selected. Any parse
# or I/O failure is reported to the user and halts this script run.
if uploaded_file is not None:
    try:
        json_data = json.load(uploaded_file)
        file_name = uploaded_file.name
        file_path_to_process = os.path.join(DATA_FOLDER, file_name)
    except json.JSONDecodeError:
        st.error("Error: The uploaded file is not a valid JSON file.")
        st.stop()
    except Exception as e:
        st.error(f"An error occurred while processing the uploaded file: {e}")
        st.stop()
elif local_file_path_input:
    if os.path.exists(local_file_path_input):
        try:
            # Explicit UTF-8: the platform-default encoding (e.g. cp1252 on
            # Windows) can reject or corrupt valid UTF-8 JSON files.
            with open(local_file_path_input, "r", encoding="utf-8") as f:
                json_data = json.load(f)
            file_name = os.path.basename(local_file_path_input)
            file_path_to_process = os.path.join(DATA_FOLDER, file_name)
        except json.JSONDecodeError:
            st.error("Error: The provided local file is not a valid JSON file.")
            st.stop()
        except Exception as e:
            st.error(f"An error occurred while processing the local file: {e}")
            st.stop()
    else:
        st.error(f"Error: The local file path '{local_file_path_input}' does not exist.")
        st.stop()
if json_data is not None:
    try:
        # Credentials come from the environment. Prefer the conventional
        # upper-case names (which the error text below advertises) and fall
        # back to the legacy lower-case names the original code used, so
        # existing deployments keep working.
        groq_api = os.environ.get("GROQ_API_KEY") or os.environ.get("groq_api")
        hf_token = os.environ.get("HF_TOKEN") or os.environ.get("hf_token")
        if not groq_api or not hf_token:
            st.error(
                "Error: API keys (GROQ_API_KEY and HF_TOKEN) not found in environment variables."
            )
            st.info(
                "Please set the environment variables GROQ_API_KEY and HF_TOKEN."
                " You can do this in your terminal before running the script:\n"
                "`export GROQ_API_KEY='YOUR_GROQ_API_KEY'`\n"
                "`export HF_TOKEN='YOUR_HUGGINGFACE_TOKEN'`"
            )
            st.stop()

        # Persist a pretty-printed copy of the input under DATA_FOLDER so the
        # structured-output step below can re-read it from disk.
        with open(file_path_to_process, "w", encoding="utf-8") as f:
            json.dump(json_data, f, indent=4)  # Save with indentation for readability
        st.success(f"File '{file_name}' successfully loaded and saved to:")
        st.code(file_path_to_process, language="plaintext")

        st.subheader("Task 1: Information Retrieval (Question-Answering)")
        if st.button("Process Data"):
            with st.spinner("Processing data..."):
                # Each record is expected to carry docText/docTitle/docDate
                # keys (a KeyError here is caught by the outer handler).
                texts = [item["docText"] for item in json_data]
                metadatas = [{"title": item["docTitle"], "date": item["docDate"]} for item in json_data]

                # Chunk the documents before embedding.
                splitter = RecursiveCharacterTextSplitter(chunk_size=700)
                docs = splitter.create_documents(texts=texts, metadatas=metadatas)

                embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
                vector_store = Chroma(
                    collection_name="Patient_data",
                    embedding_function=embeddings,
                    persist_directory="./chroma_langchain_db",
                )
                # NOTE(review): repeated "Process Data" clicks re-add the same
                # chunks to the persistent collection — consider deduplicating.
                vector_store.add_documents(documents=docs)

                llm = ChatGroq(groq_api_key=groq_api, model_name="llama-3.1-8b-instant")
                st.session_state.llm = llm  # Store llm for later use

                # Prompt that rewrites a follow-up question into a standalone
                # question using the chat history (it must NOT answer it).
                contextualize_q_prompt = ChatPromptTemplate.from_messages(
                    [
                        ("system", """Given a chat history and the latest user question
which might reference context in the chat history, formulate a standalone question
which can be understood without the chat history. Do NOT answer the question,
just reformulate it if needed and otherwise return it as is."""),
                        MessagesPlaceholder("chat_history"),
                        ("human", "{input}"),
                    ]
                )

                # Per-session chat histories, keyed by session id.
                chat_history_store = {}

                def get_chat_session_history(session_id: str) -> BaseChatMessageHistory:
                    """Return (creating on first use) the history for a session."""
                    if session_id not in chat_history_store:
                        chat_history_store[session_id] = ChatMessageHistory()
                    return chat_history_store[session_id]

                qa_prompt_template = ChatPromptTemplate.from_template("""
**Prompt:**
**Context:**
{context}
**Question:**
{input}
**Instructions:**
1. **Carefully read and understand the provided context.**
2. **Think step-by-step to formulate a comprehensive and accurate answer.**
3. **Base your response solely on the given context.**
4. **Ensure the answer is clear, concise, and easy to understand.**
5. **Ensure the answer is in small understandable points with all content.**
**Response:**
[Your detailed and well-reasoned answer]
**Note:** This prompt emphasizes careful consideration and accurate response based on the provided context.
""")
                question_answer_chain = create_stuff_documents_chain(st.session_state.llm, qa_prompt_template)
                history_aware_retriever = create_history_aware_retriever(
                    st.session_state.llm,
                    vector_store.as_retriever(
                        search_type="mmr",
                        search_kwargs={'k': 10, 'fetch_k': 50}
                    ),
                    contextualize_q_prompt
                )
                retrieval_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
                conversational_rag_chain = RunnableWithMessageHistory(
                    retrieval_chain,
                    get_chat_session_history,
                    input_messages_key="input",
                    history_messages_key="chat_history",
                    output_messages_key="answer",
                )
                # Keep the chain and histories across Streamlit reruns.
                st.session_state.conversational_rag_chain = conversational_rag_chain
                st.session_state.chat_history_store = chat_history_store
                st.success("Data processed! You can now ask questions and generate structured output.")

        if "conversational_rag_chain" in st.session_state:
            user_question = st.text_input("Ask a question about the data:", key="user_question")
            if user_question:
                session_id = "user_session"  # You might want to make this dynamic for multiple users
                with st.spinner("Generating answer..."):
                    response = st.session_state.conversational_rag_chain.invoke(
                        {"input": user_question},
                        config={"configurable": {"session_id": session_id}},
                    )
                    st.markdown(response['answer'])

            st.subheader("Generate Structured Output")
            if st.button("Generate Structured Cancer Information"):
                with st.spinner("Generating structured output..."):
                    json_data = json.loads(Path(file_path_to_process).read_text(encoding="utf-8"))
                    # Concatenate every record as pretty-printed JSON to form
                    # the extraction context passed to the LLM.
                    context = "".join(json.dumps(item, indent=4) for item in json_data)

                    @dataclass
                    class Stage:
                        """Cancer Stage information."""
                        T: str = field(metadata={"description": "T Stage"})
                        N: str = field(metadata={"description": "N Stage"})
                        M: str = field(metadata={"description": "M Stage"})
                        group_stage: str = field(metadata={"description": "Group Stage"})

                    @dataclass
                    class DiagnosisCharacteristic:
                        """Primary cancer condition details."""
                        primary_cancer_condition: str = field(metadata={"description": "Primary cancer condition Example “Breast Cancer”, “Lung Cancer”, etc which given in patient data"})
                        diagnosis_date: str = field(metadata={"description": "Earliest date on which the cancer got confirmed Diagnosis date in MM-DD-YYYY format Example: How to Find: Typically in sentences such as “The biopsy on 01/12/2020 confirmed invasive ductal carcinoma.” or “Pathology Report (02/17/2020): Invasive breast cancer.” c. You may see multiple references to diagnosis across notes; pick the earliest one that specifically confirms the cancer."})
                        histology: List[str] = field(metadata={"description": """{Histological classification of the primary cancer condition, Describes the microscopic subtype of the tumor. Common examples: “Adenocarcinoma,” “Invasive ductal carcinoma,” “Squamous cell carcinoma,” etc. b. How to Find: In pathology reports or biopsy results. Terms like “Histologically consistent with adenocarcinoma” or “Invasive ductal carcinoma, Grade 2.”}"""})
                        stage: Stage = field(metadata={"description": """{Indicates Tumor size/extent. E.g., T2 means a moderate-sized tumor, T4 might mean a larger or invasive tumor. b. N: Indicates lymph Nodes involvement. N0 means no nodal involvement, N1/N2 means progressively more nodes involved. c. M: Indicates Metastasis. M0 means no distant spread; M1 means present. d. Group Stage: A single label (Stage I, Stage IIB, Stage IV, etc.) summarizing T, N, and M combined. e. How to Find: In imaging reports, pathology final reports, or physician notes, e.g. “Stage IIB (T2 N1 M0).” or “pT2 N1 M0.”}"""})

                    @dataclass
                    class CancerRelatedMedication:
                        """Cancer related medication details."""
                        medication_name: str = field(metadata={"description": "Medication for cancer:For example, “Doxorubicin,” “Cyclophosphamide,” “Paclitaxel,” “Trastuzumab,” “Pembrolizumab,” “Letrozole,” etc. "})
                        start_date: str = field(metadata={"description": "The earliest date this medication was started, in MM-DD-YYYY format, if available. Start date in MM-DD-YYYY format"})
                        end_date: str = field(metadata={"description": "The date the medication was stopped, if mentioned. If the patient is still on the medication, you may leave it blank or mark as nullEnd date in MM-DD-YYYY format"})
                        intent: str = field(metadata={"description": "A free-text field describing why the medication was given. Examples: “Adjuvant therapy post-surgery,” “Neoadjuvant therapy to shrink tumor,” “Maintenance therapy for HER2+ disease,” or “Hormonal therapy to block estrogen in ER+ cancer.”"})

                    @dataclass
                    class CancerInformation:
                        """Structured information about cancer diagnosis and medication."""
                        diagnosis_characteristics: List[DiagnosisCharacteristic] = field(metadata={"description": "List of primary cancers"})
                        cancer_related_medications: List[CancerRelatedMedication] = field(metadata={"description": "List of cancer related medication given to the patient"})

                    llm = ChatGroq(groq_api_key=groq_api, model_name="llama-3.1-8b-instant")
                    # NOTE(review): with_structured_output is documented for
                    # Pydantic models / TypedDicts / JSON schema; confirm the
                    # installed langchain version also accepts dataclasses.
                    structured_llm = llm.with_structured_output(CancerInformation)
                    try:
                        output = structured_llm.invoke(context)
                        # The chain may return a dataclass instance or a plain
                        # dict depending on the langchain version. Convert to a
                        # serializable dict either way — the original
                        # json.dump(output, ...) crashed on dataclass instances.
                        payload = asdict(output) if is_dataclass(output) else output
                        st.subheader("Task 2: Medical Data Extraction- Generated Structured Output:")
                        st.json(payload)

                        # Save the generated output to a JSON file.
                        output_filename = f"{Path(file_path_to_process).stem}_structured.json"
                        output_filepath = os.path.join(DATA_FOLDER, output_filename)
                        with open(output_filepath, "w", encoding="utf-8") as f:
                            json.dump(payload, f, indent=4)

                        # Offer the saved file for download.
                        with open(output_filepath, "rb") as f:
                            st.download_button(
                                label="Download Generated JSON",
                                data=f,
                                file_name=output_filename,
                                mime="application/json",
                            )
                    except Exception as e:
                        st.error(f"Error generating structured output: {e}")
    except Exception as e:
        st.error(f"An unexpected error occurred: {e}")
else:
    st.info("Please upload a JSON file or enter a local file path.")
st.markdown("---")  # Horizontal rule separating the footer from the app body.

# Author contact links rendered as markdown hyperlinks.
for footer_link in (
    "[My linkedin](https://www.linkedin.com/in/darshankumarr/)",
    "[Resume Link](https://drive.google.com/file/d/1HAL5NmUjT5bfa-NIgo-kVQ93-ISzGijh/view?usp=drive_link)",
):
    st.markdown(footer_link)