Darshan03 committed on
Commit
b71faa5
·
verified ·
1 Parent(s): 6d41a2c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -4
app.py CHANGED
@@ -17,6 +17,10 @@ from langchain_community.chat_message_histories import ChatMessageHistory
17
  from langchain.chains import create_history_aware_retriever, create_retrieval_chain
18
  from langchain.chains.combine_documents import create_stuff_documents_chain
19
  from IPython.display import Markdown, display
 
 
 
 
20
 
21
  # Define the data folder path
22
  DATA_FOLDER = "data"
@@ -86,7 +90,8 @@ if uploaded_file is not None:
86
  )
87
  vector_store.add_documents(documents=docs)
88
 
89
- llm = ChatGroq(groq_api_key=groq_api, model_name="llama-3.3-70b-versatile")
 
90
 
91
  contextualize_q_prompt = ChatPromptTemplate.from_messages(
92
  [
@@ -132,10 +137,10 @@ just reformulate it if needed and otherwise return it as is."""),
132
  **Note:** This prompt emphasizes careful consideration and accurate response based on the provided context.
133
  """)
134
 
135
- question_answer_chain = create_stuff_documents_chain(llm, qa_prompt_template)
136
 
137
  history_aware_retriever = create_history_aware_retriever(
138
- llm,
139
  vector_store.as_retriever(
140
  search_type="mmr",
141
  search_kwargs={'k': 10, 'fetch_k': 50}
@@ -155,7 +160,7 @@ just reformulate it if needed and otherwise return it as is."""),
155
 
156
  st.session_state.conversational_rag_chain = conversational_rag_chain
157
  st.session_state.chat_history_store = chat_history_store
158
- st.success("Data processed! You can now ask questions.")
159
 
160
  if "conversational_rag_chain" in st.session_state:
161
  user_question = st.text_input("Ask a question about the data:", key="user_question")
@@ -168,6 +173,68 @@ just reformulate it if needed and otherwise return it as is."""),
168
  )
169
  st.markdown(response['answer'])
170
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  except json.JSONDecodeError:
172
  st.error("Error: The uploaded file is not a valid JSON file.")
173
  except Exception as e:
 
17
  from langchain.chains import create_history_aware_retriever, create_retrieval_chain
18
  from langchain.chains.combine_documents import create_stuff_documents_chain
19
  from IPython.display import Markdown, display
20
+ from langchain_core.output_parsers import JsonOutputParser
21
+ from langchain_core.pydantic_v1 import BaseModel, Field
22
+ from typing import List, Optional
23
+ from dataclasses import dataclass, field
24
 
25
  # Define the data folder path
26
  DATA_FOLDER = "data"
 
90
  )
91
  vector_store.add_documents(documents=docs)
92
 
93
+ llm = ChatGroq(groq_api_key=groq_api, model_name="llama-3.1-8b-instant")
94
+ st.session_state.llm = llm # Store llm for later use
95
 
96
  contextualize_q_prompt = ChatPromptTemplate.from_messages(
97
  [
 
137
  **Note:** This prompt emphasizes careful consideration and accurate response based on the provided context.
138
  """)
139
 
140
+ question_answer_chain = create_stuff_documents_chain(st.session_state.llm, qa_prompt_template)
141
 
142
  history_aware_retriever = create_history_aware_retriever(
143
+ st.session_state.llm,
144
  vector_store.as_retriever(
145
  search_type="mmr",
146
  search_kwargs={'k': 10, 'fetch_k': 50}
 
160
 
161
  st.session_state.conversational_rag_chain = conversational_rag_chain
162
  st.session_state.chat_history_store = chat_history_store
163
+ st.success("Data processed! You can now ask questions and generate structured output.")
164
 
165
  if "conversational_rag_chain" in st.session_state:
166
  user_question = st.text_input("Ask a question about the data:", key="user_question")
 
173
  )
174
  st.markdown(response['answer'])
175
 
176
st.subheader("Generate Structured Output")
if st.button("Generate Structured Cancer Information"):
    with st.spinner("Generating structured output..."):
        # json.loads accepts bytes and detects the UTF encoding itself, so the
        # result does not depend on the platform's default text encoding.
        json_data = json.loads(Path(file_path).read_bytes())
        # Iterating a top-level JSON object yields only its keys, which would
        # drop every value from the prompt; wrap a non-list document so it is
        # serialized whole.
        records = json_data if isinstance(json_data, list) else [json_data]
        context = "".join(json.dumps(item, indent=4) for item in records)

        # The schema is declared with Pydantic models (BaseModel / Field are
        # imported at the top of the file) rather than stdlib dataclasses:
        # with_structured_output documents Pydantic classes as a supported
        # schema type, Field(description=...) is carried into the generated
        # schema, and the parsed result exposes .dict() as used below.
        class Stage(BaseModel):
            """Cancer Stage information."""

            T: str = Field(..., description="T Stage")
            N: str = Field(..., description="N Stage")
            M: str = Field(..., description="M Stage")
            group_stage: str = Field(..., description="Group Stage")

        class DiagnosisCharacteristic(BaseModel):
            """Primary cancer condition details."""

            primary_cancer_condition: str = Field(
                ...,
                description="Primary cancer condition Example “Breast Cancer”, “Lung Cancer”, etc which given in patient data",
            )
            diagnosis_date: str = Field(
                ...,
                description="Earliest date on which the cancer got confirmed Diagnosis date in MM-DD-YYYY format Example: How to Find: Typically in sentences such as “The biopsy on 01/12/2020 confirmed invasive ductal carcinoma.” or “Pathology Report (02/17/2020): Invasive breast cancer.” c. You may see multiple references to diagnosis across notes; pick the earliest one that specifically confirms the cancer.",
            )
            histology: List[str] = Field(
                ...,
                description="""{Histological classification of the primary cancer condition, Describes the microscopic subtype of the tumor. Common examples: “Adenocarcinoma,” “Invasive ductal carcinoma,” “Squamous cell carcinoma,” etc. b. How to Find: In pathology reports or biopsy results. Terms like “Histologically consistent with adenocarcinoma” or “Invasive ductal carcinoma, Grade 2.”}""",
            )
            stage: Stage = Field(
                ...,
                description="""{Indicates Tumor size/extent. E.g., T2 means a moderate-sized tumor, T4 might mean a larger or invasive tumor. b. N: Indicates lymph Nodes involvement. N0 means no nodal involvement, N1/N2 means progressively more nodes involved. c. M: Indicates Metastasis. M0 means no distant spread; M1 means present. d. Group Stage: A single label (Stage I, Stage IIB, Stage IV, etc.) summarizing T, N, and M combined. e. How to Find: In imaging reports, pathology final reports, or physician notes, e.g. “Stage IIB (T2 N1 M0).” or “pT2 N1 M0.”}""",
            )

        class CancerRelatedMedication(BaseModel):
            """Cancer related medication details."""

            medication_name: str = Field(
                ...,
                description="Medication for cancer:For example, “Doxorubicin,” “Cyclophosphamide,” “Paclitaxel,” “Trastuzumab,” “Pembrolizumab,” “Letrozole,” etc. ",
            )
            # The two dates are optional because their own descriptions allow
            # them to be absent ("if available" / "leave it blank or mark as
            # null"); a required str would reject such an answer.
            start_date: Optional[str] = Field(
                None,
                description="The earliest date this medication was started, in MM-DD-YYYY format, if available. Start date in MM-DD-YYYY format",
            )
            end_date: Optional[str] = Field(
                None,
                description="The date the medication was stopped, if mentioned. If the patient is still on the medication, you may leave it blank or mark as nullEnd date in MM-DD-YYYY format",
            )
            intent: str = Field(
                ...,
                description="A free-text field describing why the medication was given. Examples: “Adjuvant therapy post-surgery,” “Neoadjuvant therapy to shrink tumor,” “Maintenance therapy for HER2+ disease,” or “Hormonal therapy to block estrogen in ER+ cancer.”",
            )

        class CancerInformation(BaseModel):
            """Structured information about cancer diagnosis and medication."""

            diagnosis_characteristics: List[DiagnosisCharacteristic] = Field(
                ..., description="List of primary cancers"
            )
            cancer_related_medications: List[CancerRelatedMedication] = Field(
                ...,
                description="List of cancer related medication given to the patient",
            )

        structured_llm = st.session_state.llm.with_structured_output(CancerInformation)
        try:
            output = structured_llm.invoke(context)
            structured = output.dict()
            st.subheader("Generated Structured Output:")
            st.json(structured)

            # Serialize once and reuse the same text for the saved file and
            # for the download, instead of re-opening the file just written.
            payload = json.dumps(structured, indent=4)

            # Save the generated output to a JSON file
            output_filename = f"{Path(file_path).stem}_structured.json"
            output_filepath = os.path.join(DATA_FOLDER, output_filename)
            with open(output_filepath, "w") as f:
                f.write(payload)

            # Provide a download button
            st.download_button(
                label="Download Generated JSON",
                data=payload,
                file_name=output_filename,
                mime="application/json",
            )

        except Exception as e:
            # UI boundary: surface any model, validation or I/O failure to the
            # user instead of crashing the app.
            st.error(f"Error generating structured output: {e}")
237
+
238
  except json.JSONDecodeError:
239
  st.error("Error: The uploaded file is not a valid JSON file.")
240
  except Exception as e: