bhavika24 commited on
Commit
324600f
·
verified ·
1 Parent(s): 0c140fa

Upload 2 files

Browse files
Files changed (1) hide show
  1. engine.py +58 -27
engine.py CHANGED
@@ -136,17 +136,43 @@ def extract_relevant_tables(question, max_tables=4):
136
 
137
  # Build hints only for tables that actually exist
138
  hint_mappings = {
139
- "patient": ["patients"],
140
- "admission": ["admissions"],
141
- "visit": ["admissions", "icustays"],
142
- "icu": ["icustays", "chartevents"],
143
- "diagnosis": ["diagnoses_icd"],
144
- "procedure": ["procedures_icd"],
145
- "medication": ["prescriptions", "emar", "pharmacy"],
146
- "lab": ["labevents"],
147
- "vital": ["chartevents"],
148
- "stay": ["icustays"]
149
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
 
151
  # Only include hints for tables that exist in the schema
152
  for intent, possible_tables in hint_mappings.items():
@@ -716,18 +742,24 @@ def process_question(question):
716
  # Unsupported question check
717
  # ----------------------------------
718
  if not is_question_supported(question):
719
- return {
720
- "status": "ok",
721
- "message": (
722
- "That information isn’t available in the system.\n\n"
723
- "You can ask about:\n"
724
- "• Patients\n"
725
- " Visits\n"
726
- " Conditions\n"
727
- " Medications"
728
- ),
729
- "data": []
730
- }
 
 
 
 
 
 
731
 
732
  # ----------------------------------
733
  # Generate SQL
@@ -746,6 +778,7 @@ def process_question(question):
746
  }
747
 
748
 
 
749
  if sql == "NOT_ANSWERABLE":
750
  return {
751
  "status": "ok",
@@ -822,13 +855,10 @@ def process_question(question):
822
  "columns": cols,
823
  "data": rows
824
  }
825
- def download_transcript_json():
826
- import json
827
- return json.dumps(TRANSCRIPT, indent=2)
828
 
829
  def download_transcript_txt():
830
  lines = []
831
- for i, entry in enumerate(TRANSCRIPT, 1):
832
  lines.append(f"\n--- Query {i} ---")
833
  lines.append(f"Time: {entry['timestamp']}")
834
  lines.append(f"Question: {entry['question']}")
@@ -845,3 +875,4 @@ def download_transcript_txt():
845
  return "\n".join(lines)
846
 
847
 
 
 
136
 
137
  # Build hints only for tables that actually exist
138
  hint_mappings = {
139
+ # Patients & visits
140
+ "patient": ["patients"],
141
+ "patients": ["patients"],
142
+
143
+ "admission": ["admissions"],
144
+ "admissions": ["admissions"],
145
+ "visit": ["admissions", "icustays"],
146
+ "visits": ["admissions", "icustays"],
147
+
148
+ # ICU
149
+ "icu": ["icustays", "chartevents"],
150
+ "stay": ["icustays"],
151
+ "stays": ["icustays"],
152
+
153
+ # Diagnoses / conditions
154
+ "diagnosis": ["diagnoses_icd"],
155
+ "diagnoses": ["diagnoses_icd"],
156
+ "condition": ["diagnoses_icd"],
157
+ "conditions": ["diagnoses_icd"],
158
+
159
+ # Procedures
160
+ "procedure": ["procedures_icd"],
161
+ "procedures": ["procedures_icd"],
162
+
163
+ # Medications
164
+ "medication": ["prescriptions", "emar", "pharmacy"],
165
+ "medications": ["prescriptions", "emar", "pharmacy"],
166
+ "drug": ["prescriptions"],
167
+ "drugs": ["prescriptions"],
168
+
169
+ # Labs & vitals
170
+ "lab": ["labevents"],
171
+ "labs": ["labevents"],
172
+ "vital": ["chartevents"],
173
+ "vitals": ["chartevents"],
174
+ }
175
+
176
 
177
  # Only include hints for tables that exist in the schema
178
  for intent, possible_tables in hint_mappings.items():
 
742
  # Unsupported question check
743
  # ----------------------------------
744
  if not is_question_supported(question):
745
+ log_interaction(
746
+ user_q=question,
747
+ error="Unsupported question"
748
+ )
749
+
750
+ return {
751
+ "status": "ok",
752
+ "message": (
753
+ "That information isn’t available in the system.\n\n"
754
+ "You can ask about:\n"
755
+ "• Patients\n"
756
+ "• Visits\n"
757
+ "• Conditions\n"
758
+ "• Medications"
759
+ ),
760
+ "data": []
761
+ }
762
+
763
 
764
  # ----------------------------------
765
  # Generate SQL
 
778
  }
779
 
780
 
781
+
782
  if sql == "NOT_ANSWERABLE":
783
  return {
784
  "status": "ok",
 
855
  "columns": cols,
856
  "data": rows
857
  }
 
 
 
858
 
859
  def download_transcript_txt():
860
  lines = []
861
+ for i, entry in enumerate(st.session_state.transcript, 1):
862
  lines.append(f"\n--- Query {i} ---")
863
  lines.append(f"Time: {entry['timestamp']}")
864
  lines.append(f"Question: {entry['question']}")
 
875
  return "\n".join(lines)
876
 
877
 
878
+