Spaces:
Sleeping
Sleeping
Upload 2 files
Browse files
UI.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
-
from engine import process_question
|
|
|
|
|
|
|
| 3 |
|
| 4 |
st.set_page_config(page_title="Hospital AI Assistant", layout="wide")
|
| 5 |
|
|
@@ -88,3 +90,30 @@ if user_input:
|
|
| 88 |
|
| 89 |
with st.chat_message("assistant"):
|
| 90 |
st.markdown(reply, unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
+
from engine import process_question,download_transcript_txt,download_transcript_json
|
| 3 |
+
|
| 4 |
+
|
| 5 |
|
| 6 |
st.set_page_config(page_title="Hospital AI Assistant", layout="wide")
|
| 7 |
|
|
|
|
| 90 |
|
| 91 |
with st.chat_message("assistant"):
|
| 92 |
st.markdown(reply, unsafe_allow_html=True)
|
| 93 |
+
# =========================
|
| 94 |
+
# Download Conversation
|
| 95 |
+
# =========================
|
| 96 |
+
st.divider()
|
| 97 |
+
st.subheader("📥 Download Conversation")
|
| 98 |
+
|
| 99 |
+
col1, col2 = st.columns(2)
|
| 100 |
+
|
| 101 |
+
with col1:
|
| 102 |
+
if st.button("Download Transcript (TXT)"):
|
| 103 |
+
txt = download_transcript_txt()
|
| 104 |
+
st.download_button(
|
| 105 |
+
label="Download TXT",
|
| 106 |
+
data=txt,
|
| 107 |
+
file_name="chat_transcript.txt",
|
| 108 |
+
mime="text/plain"
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
with col2:
|
| 112 |
+
if st.button("Download Transcript (JSON)"):
|
| 113 |
+
js = download_transcript_json()
|
| 114 |
+
st.download_button(
|
| 115 |
+
label="Download JSON",
|
| 116 |
+
data=js,
|
| 117 |
+
file_name="chat_transcript.json",
|
| 118 |
+
mime="application/json"
|
| 119 |
+
)
|
engine.py
CHANGED
|
@@ -3,6 +3,10 @@ import re
|
|
| 3 |
import sqlite3
|
| 4 |
from openai import OpenAI
|
| 5 |
from difflib import get_close_matches
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
# =========================
|
| 8 |
# SETUP
|
|
@@ -623,6 +627,16 @@ def build_table_summary(table_name):
|
|
| 623 |
# =========================
|
| 624 |
|
| 625 |
def process_question(question):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 626 |
global LAST_PROMPT_TYPE, LAST_SUGGESTED_DATE
|
| 627 |
|
| 628 |
q = question.strip().lower()
|
|
@@ -720,12 +734,16 @@ def process_question(question):
|
|
| 720 |
try:
|
| 721 |
sql = call_llm(build_prompt(question))
|
| 722 |
except ValueError as e:
|
| 723 |
-
|
| 724 |
-
|
| 725 |
-
|
| 726 |
-
|
| 727 |
-
|
| 728 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 729 |
|
| 730 |
if sql == "NOT_ANSWERABLE":
|
| 731 |
return {
|
|
@@ -739,6 +757,13 @@ def process_question(question):
|
|
| 739 |
sql = correct_table_names(sql)
|
| 740 |
sql = validate_sql(sql)
|
| 741 |
cols, rows = run_query(sql)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 742 |
|
| 743 |
# ----------------------------------
|
| 744 |
# No data handling
|
|
@@ -755,6 +780,12 @@ def process_question(question):
|
|
| 755 |
}
|
| 756 |
|
| 757 |
if not rows:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 758 |
LAST_PROMPT_TYPE = "NO_DATA"
|
| 759 |
LAST_SUGGESTED_DATE = get_latest_data_date()
|
| 760 |
|
|
@@ -774,4 +805,26 @@ def process_question(question):
|
|
| 774 |
"columns": cols,
|
| 775 |
"data": rows
|
| 776 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 777 |
|
|
|
|
| 3 |
import sqlite3
|
| 4 |
from openai import OpenAI
|
| 5 |
from difflib import get_close_matches
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
|
| 8 |
+
TRANSCRIPT = []
|
| 9 |
+
|
| 10 |
|
| 11 |
# =========================
|
| 12 |
# SETUP
|
|
|
|
| 627 |
# =========================
|
| 628 |
|
| 629 |
def process_question(question):
|
| 630 |
+
|
| 631 |
+
def log_interaction(user_q, sql=None, result=None, error=None):
|
| 632 |
+
TRANSCRIPT.append({
|
| 633 |
+
"timestamp": datetime.utcnow().isoformat(),
|
| 634 |
+
"question": user_q,
|
| 635 |
+
"sql": sql,
|
| 636 |
+
"result_preview": result[:10] if isinstance(result, list) else result,
|
| 637 |
+
"error": error
|
| 638 |
+
})
|
| 639 |
+
|
| 640 |
global LAST_PROMPT_TYPE, LAST_SUGGESTED_DATE
|
| 641 |
|
| 642 |
q = question.strip().lower()
|
|
|
|
| 734 |
try:
|
| 735 |
sql = call_llm(build_prompt(question))
|
| 736 |
except ValueError as e:
|
| 737 |
+
log_interaction(
|
| 738 |
+
user_q=question,
|
| 739 |
+
error=str(e)
|
| 740 |
+
)
|
| 741 |
+
return {
|
| 742 |
+
"status": "ok",
|
| 743 |
+
"message": str(e),
|
| 744 |
+
"data": []
|
| 745 |
+
}
|
| 746 |
+
|
| 747 |
|
| 748 |
if sql == "NOT_ANSWERABLE":
|
| 749 |
return {
|
|
|
|
| 757 |
sql = correct_table_names(sql)
|
| 758 |
sql = validate_sql(sql)
|
| 759 |
cols, rows = run_query(sql)
|
| 760 |
+
|
| 761 |
+
log_interaction(
|
| 762 |
+
user_q=question,
|
| 763 |
+
sql=sql,
|
| 764 |
+
result=rows
|
| 765 |
+
)
|
| 766 |
+
|
| 767 |
|
| 768 |
# ----------------------------------
|
| 769 |
# No data handling
|
|
|
|
| 780 |
}
|
| 781 |
|
| 782 |
if not rows:
|
| 783 |
+
log_interaction(
|
| 784 |
+
user_q=question,
|
| 785 |
+
sql=sql,
|
| 786 |
+
result=[]
|
| 787 |
+
)
|
| 788 |
+
|
| 789 |
LAST_PROMPT_TYPE = "NO_DATA"
|
| 790 |
LAST_SUGGESTED_DATE = get_latest_data_date()
|
| 791 |
|
|
|
|
| 805 |
"columns": cols,
|
| 806 |
"data": rows
|
| 807 |
}
|
| 808 |
+
def download_transcript_json():
|
| 809 |
+
import json
|
| 810 |
+
return json.dumps(TRANSCRIPT, indent=2)
|
| 811 |
+
|
| 812 |
+
def download_transcript_txt():
|
| 813 |
+
lines = []
|
| 814 |
+
for i, entry in enumerate(TRANSCRIPT, 1):
|
| 815 |
+
lines.append(f"\n--- Query {i} ---")
|
| 816 |
+
lines.append(f"Time: {entry['timestamp']}")
|
| 817 |
+
lines.append(f"Question: {entry['question']}")
|
| 818 |
+
|
| 819 |
+
if entry.get("sql"):
|
| 820 |
+
lines.append(f"SQL: {entry['sql']}")
|
| 821 |
+
|
| 822 |
+
if entry.get("result_preview"):
|
| 823 |
+
lines.append(f"Result Preview: {entry['result_preview']}")
|
| 824 |
+
|
| 825 |
+
if entry.get("error"):
|
| 826 |
+
lines.append(f"Error: {entry['error']}")
|
| 827 |
+
|
| 828 |
+
return "\n".join(lines)
|
| 829 |
+
|
| 830 |
|