# selfservicedemo / app.py
# Benedette's picture
# Update app.py
# b8e062a verified
from sqlalchemy import create_engine
import os
#from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
#import models
from urllib.parse import quote_plus

# Database connection settings are read from the environment. Indexing
# os.environ (instead of .get) makes a missing variable fail at start-up with
# a KeyError that names it. The values must be bound to names: the connection
# string below refers to them.
DB_PASSWORD = os.environ["DB_PASSWORD"]
DB_HOST = os.environ["DB_HOST"]
DB_PORT = os.environ["DB_PORT"]
DB = os.environ["DB"]
# Construct the connection string. The password is percent-encoded so that
# characters such as "@", ":" or "/" inside it cannot be mistaken for URL
# delimiters.
SQL_DATABASE_URL = (
    f"mssql+pymssql://Benedette:{quote_plus(DB_PASSWORD)}"
    f"@{DB_HOST}:{DB_PORT}/{DB}"
)
# Create an engine instance
engine = create_engine(SQL_DATABASE_URL, connect_args={}, echo=True)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
#linelist_factart_schema = models.t_Linelist_FACTART
# Base = declarative_base()
# NOTE(review): this session is not used in the visible part of the file and
# is never closed -- confirm whether it is still needed.
db = SessionLocal()
# linelist_factart_schema
# a wrapper around the SQLAlchemy engine to interact with a SQL database.
from llama_index.core import SQLDatabase
# sql_database = SQLDatabase(engine)
# Only these tables are handed to SQLDatabase via include_tables.
tables = [
    "Linelist_FACTART",
    "LineListTransHTS",
    "LinelistPrep",
    "LinelistHEI",
    "AggregateDSD",
    "AggregateOTZEligibilityAndEnrollments",
    "AggregateARTHistory",
]
sql_database = SQLDatabase(engine, include_tables=tables)
# Fail at start-up, with the same KeyError a bare lookup would raise, when the
# OpenAI key is not configured. The value itself is not needed here.
if "OPENAI_API_KEY" not in os.environ:
    raise KeyError("OPENAI_API_KEY")
from llama_index.llms.openai import OpenAI
# temperature=0 is requested for the model.
llm = OpenAI(temperature=0, model="gpt-3.5-turbo")
# Free-text description of the Linelist_FACTART table; it is attached to that
# table as the context_str of its SQLTableSchema.
# The fragments are joined explicitly with newlines. Relying on implicit
# literal concatenation glued neighbouring sentences together
# ("...clinical parameters.Use this table...").
fact_linelist_str = "\n".join(
    [
        "A client-level linelist that contains comprehensive data on all clients who have ever received treatment for HIV/AIDS, ",
        "encompassing various indicators and clinical parameters.",
        "Use this table to answer quetions related to active patients currently on treatment/txcurr, viral load results",
        "Key attributes captured in this linelist include:",
        "if input in the NUPI column is NULL then the client has no NUPI",
        "Active patients is where ARTOutcomeDescription is Active.",
        "Clinical indicators like Last CD4 count, Last VL (Viral Load), and WHO Stage, aiding in the assessment of disease progression and treatment response.",
        "Demographic information, including Age at ART Start, Gender, Marital Status, and County/Sub-County, facilitating analysis of patient demographics.",
        "Medical history and co-morbidities, such as Diabetes and Hypertension status, providing context on underlying health conditions and associated risk factors.",
        "Facility-related data, such as Facility Name, Site Code,County, and Partner Name, enabling assessment of service delivery across different healthcare facilities and implementing partners.",
        "Pregnancy-related indicators, including Pregnant ART Start and Pregnant at Enrollment, supporting maternal and child health monitoring and intervention.",
        "clients are uniquley identified by concatenating PatientPKHash and Sitecode",
        "LastVL is the most current VL for the client",
        "LowViremia/suppressed is when a client viral load is less that 200 copies per ml,1= True 0 = False ",
        "HighViremia/unsuppressed is when a client viral load is more that 200 copies per ml,1= True 0 = False ",
        "HasValidVL is boolean value for if client has a valid VL",
        "ISTxCurs indicates whether the patients are active on teatment where 1= True",
        "Treatment Outcomes(ARTOutcomeDescription) is as at one point",
        "This table can be used to answer queries such as:",
        "What is the distribution of treatment outcomes among HIV/AIDS patients, such as Active, Transfer Out, and Loss to Follow-Up by county, partner,age?",
        "What proportion of patients have achieved viral suppression, as indicated by their Last VL results by coounty?",
        "What percentage of HIV/AIDS patients have co-morbid conditions such as diabetes or hypertension?",
    ]
)
# Free-text description of the LineListTransHTS table; it is attached to that
# table as the context_str of its SQLTableSchema.
# One fragment per line, joined with newlines, so that sentences and the
# "- Column (Name): ..." bullets no longer run into each other as they did
# with implicit literal concatenation.
hts_linelist_str = "\n".join(
    [
        "A client-level linelist containing comprehensive HIV testing data for all adult clients (> 18 years) who have undergone HIV testing.",
        "This dataset captures a wide range of information including demographic details, testing outcomes, testing history, and programmatic indicators.",
        "It serves as a valuable table for analyzing HIV testing patterns, testing outcomes, and testing strategies among adult populations.",
        "Please note that this dataset is not suitable for inquiries related to patients on treatment.",
        "use this table to answer any questions related to HIV testing",
        "Additional details available include:",
        "- Age at Testing (AgeAtTesting): Age of the client at the time of HIV testing.",
        "- Age Group (AgeGroup): Categorization of clients into 4-year age bands from 1 to 64 years.",
        "- Agency Name (AgencyName): Name of the funding body or organization supporting the testing program.",
        "- Client Self-Tested (ClientSelfTested): Indicates whether a client has ever performed self-testing for HIV.",
        "- Client Tested As (ClientTestedAs): Categorizes clients based on whether they were tested individually or as part of a couple.",
        "- County (County) and Sub-County (SubCounty): Geographic location of the testing facility.",
        "- Couple Discordant (CoupleDiscordant): Indicates whether a couple tested together was concordant or discordant for HIV.",
        "- Date of Birth (DOB): Date of birth of the client.",
        "- Enrollment Date (EnrollmentDate): Date when the client was enrolled into the CCC.",
        "- Entry Point (EntryPoint): Service point where the HIV test was conducted (e.g., VCT, OPD).",
        "- Ever Tested for HIV (EverTestedForHiv): Indicates whether the client has ever been tested for HIV before.",
        "- Facility Name (FacilityName) and MFL Code (MFLCode): Name and code of the testing facility.",
        "- Final Test Result (FinalTestResult): Result of the HIV test for the encounter.",
        "- Gender (Gender) and Marital Status (MaritalStatus): Demographic characteristics of the client.",
        "- Linked (Linked): Boolean value indicating whether the client was successfully linked to follow-up services.",
        "- Months Since Last Test (MonthsSinceLastTest): Number of months since the client's last HIV test.",
        "- Test Date (TestDate): Date when the client was tested for HIV.",
        "- Test Strategy (TestStrategy): Strategy employed for HIV testing (e.g., Hospital Patient, Non-Patient).",
        "- Test Type (TestType): Type of HIV test conducted during the encounter.",
        "- Tested (Tested): Boolean value indicating whether the client was tested for HIV.",
        "- Tested Before (TestedBefore): Indicates if the client has been tested for HIV within the last 12 months.",
        "- TB Screening (tbScreening): Outcome of TB screening conducted during the encounter.",
        "Positivity rate is number of positive tests from all the test conducted in a certain period",
    ]
)
# Free-text description of the PrEP line list table; it is passed as the
# context_str of that table's SQLTableSchema.
# Fragments are joined with newlines so that adjacent sentences and bullets
# stay separated (implicit literal concatenation produced
# "...programs.Additional information...").
prep_str = "\n".join(
    [
        "A client-level line list containing comprehensive information on all clients enrolled in Pre-Exposure Prophylaxis (PrEP) programs.",
        "Additional information available includes:",
        "- As of Date (AsofDate): End of the reporting month for the data.",
        "- Assessment Month (AssessmentMonth) and Assessment Year (AssessmentYear): Month and year when the client was assessed for PrEP enrollment.",
        "- Eligible for PrEP (EligiblePrep): Boolean value indicating whether the client is eligible for enrollment in PrEP based on risk category.",
        "- Latest HIV Risk Category (LatestHIVRiskCategory): Last recorded risk category from the HIV testing machine learning model.",
        "- Screened for PrEP (ScreenedPrep): Boolean value indicating whether the client was assessed for enrollment into PrEP.",
        "- PatientPKHash: Hashed value representing the unique client ID in the specific facility.",
        "Use this table to answer any prep related question,i.e from high risk clients how many were enrolled in Prep",
    ]
)
# Free-text description of the LinelistHEI table; it is attached to that
# table as the context_str of its SQLTableSchema.
# Fragments are joined with newlines instead of implicit literal
# concatenation, which fused them ("...HIV-exposed infantsthis table...").
hei_str = "\n".join(
    [
        "A client level linelist that contains various indicators of HIV-exposed infants",
        "this table should be used for any HEI related questions,Iincluding whether HEI is breastfeeding,tested at different timepoints, ",
        "outcome of Hei after they exit the HEI program",
        "Additional information available includes:",
        "-BF12mnths Indicates whether the HEI is breastfeeding at 12 months of age as of last cwc visit",
        "-BF18mnths Indicates whether the HEI is breastfeeding at 18 months of age as of last cwc visit",
        "-EBF6mnths Indicates whether the HEI is using Exclusive Replacement(ERF) feeding method at 6 months of age as of last cwc visit",
        "-HEIExitCriteria the Exit reason for an exposed infant after 24months",
        # A space was missing after the column name here, which made the
        # column read as "InitialPCRBtwn8wks_12mnthsIndicates".
        "-InitialPCRBtwn8wks_12mnths Indicates whether the HEI's DNAPCR1 was done at age of between 8 weeks and 48 weeks",
        "-TestedAt12months-Indicates whether the HEI's DNAPCR2 was done at age of 12 months of age",
    ]
)
# Free-text description of the AggregateOTZEligibilityAndEnrollments table; it
# is attached to that table as the context_str of its SQLTableSchema.
# Each "Column: meaning" entry sits on its own line. With implicit literal
# concatenation every entry ran straight into the next column name
# ("...from 1 to 64 yearsCompletedToday_OTZ_Beyond: ...").
otzenroll_str = "\n".join(
    [
        "An aggregate table that contains counts of TXCurr/number of active individuals between 10 and 19 years who are eligible for OTZ program, /enrolled in OTZ,",
        "completed training modules and eligible for VL",
        "Use this table for addressing any inquiries regarding OTZ and corresponding viral loads",
        "AgeGroup: A 4-year age band from 1 to 64 years",
        "CompletedToday_OTZ_Beyond: Has the client completed OTZ_Beyond today",
        "CompletedToday_OTZ_Leadership: Has the client completed OTZ_Leadership today",
        "CompletedToday_OTZ_MakingDecisions: Has the client completed OTZ_MakingDecisions today",
        "CompletedToday_OTZ_Orientation: Has the client completed OTZ_Orientation today",
        "CompletedToday_OTZ_Participation: Has the client completed OTZ_Participation today",
        "CompletedToday_OTZ_SRH: Has the client completed OTZ_SRH today",
        "CompletedToday_OTZ_Transition: Has the client completed OTZ_Transition today",
        "CompletedToday_OTZ_TreatmentLiteracy: Has the client completed OTZ_TreatmentLiteracy today",
        "CompletedTraining: Number of clients who have completed OTZ modules training",
        "County: The County where the facility is located",
        "EligibleVL: Is the client eligible for a viral load",
        "Enrolled: Number of clients enrolled into OTZ program",
        "FacilityName: The facility name as entered in KHMFL",
        "FirstVL: The first ever documented viral load",
        "Gender: Sex of the patient",
        "HasValidVL: Does the client have a valid viral load",
        "LastVL: This is the most current Viral load for the client -",
        "LoadDate: Date when the dataset was ETL loaded",
        "MFLCode: Master facility code as assigned in the KHMFL",
        "ModulesPreviouslyCovered: Modules that the client has covered before this visit",
        "OTZEnrollmentYearMonth: The year and the month the client was enrolled in OTZ program",
        "PartnerName: The implementing partner mechanism",
        "SubCounty: The Sub County where the facility is located",
        "TransferInStatus: Did the client transfer in",
        "ValidVLResult: The VL result that is within 12 months from the reporting period taking into account age group validity",
        "ValidVLResultCategory: The viral load results categorizations as LDL, High-risk LLV, Low-risk LLV, and unsuppressed",
        "patients_eligible: Number of clients eligible for enrollment into OTZ program",
    ]
)
# Free-text description of the AggregateARTHistory table; it is attached to
# that table as the context_str of its SQLTableSchema.
# Fragments are joined with newlines; implicit literal concatenation fused
# them ("...at overtimeNumber of acive patients...").
aggtxcurr_str = "\n".join(
    [
        "An aggregate dataset containing counts of active number of patients/TxCurr for each facility at each month, disaggregated by various indicators.",
        "Query this table To identify increase/decrease the total number of active patients at overtime",
        "Number of acive patients or treatment is calculated at end of the month ",
        "AsofDateKey:the End of month reporting date (format = yyyy-mm-dd), use this date to extract number of active client as at that month ",
        "DATIMAgeGroup: The DATIM Age disaggregations",
        #"NumofPatients The total number of active patients/TXcurr"
        "isTxCurr: The total number of active patients/TXcurr",
    ]
)
from llama_index.core.objects import (
SQLTableNodeMapping,
ObjectIndex,
SQLTableSchema,
)
from llama_index.core import VectorStoreIndex
#store the table schema in an index
table_node_mapping = SQLTableNodeMapping(sql_database)
#store schema information for each table.
# Each entry pairs a table name with the free-text description defined above.
table_schema_objs = [
    SQLTableSchema(table_name="Linelist_FACTART", context_str=fact_linelist_str),
    SQLTableSchema(table_name="LineListTransHTS", context_str=hts_linelist_str),
    # Spelled exactly as in `tables` ("LinelistPrep"); this entry previously
    # used a different capitalisation ("LineListPrep") for the same table.
    SQLTableSchema(table_name="LinelistPrep", context_str=prep_str),
    SQLTableSchema(table_name="LinelistHEI", context_str=hei_str),
    SQLTableSchema(
        table_name="AggregateOTZEligibilityAndEnrollments",
        context_str=otzenroll_str,
    ),
    #(SQLTableSchema(table_name="AggregateDSD", context_str=dsd_str)),
    SQLTableSchema(table_name="AggregateARTHistory", context_str=aggtxcurr_str),
]
# The mapping and the index class are passed by keyword so they cannot bind
# to the wrong parameter if the positional order of from_objects differs
# between llama_index releases.
obj_index = ObjectIndex.from_objects(
    table_schema_objs,  # A list of table schema objects
    object_mapping=table_node_mapping,  # An object responsible for mapping tables to nodes.
    index_cls=VectorStoreIndex,  # for vector-based searching or indexing.
)
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
# Text-to-SQL engine: for each question the retriever picks the two
# best-matching table descriptions (similarity_top_k=2).
# The module-level `llm` is passed explicitly; it was created with
# temperature=0 but was never handed to the engine.
query_engine = SQLTableRetrieverQueryEngine(
    sql_database,
    obj_index.as_retriever(similarity_top_k=2),
    llm=llm,
)
# Instruction text placed in front of every question sent to the query engine.
# Every fragment except the last ends with a space: adjacent literals are
# concatenated as-is, and without the spaces words fused together
# ("correctquery", "answerYou", "mostinteresting", "schemadescription").
preamble = (
    "Given an input question, first create a syntactically correct "
    "query to run, then look at the results of the query and return the answer. "
    "You can order the results by a relevant column to return the most "
    "interesting examples in the database. "
    "Pay attention to use only the column names that you can see in the schema "
    "description. Be careful to not query for columns that do not exist. "
    "Pay attention to which column is in which table. Also, qualify column names "
    "with the table name when needed."
)
# Marker placed immediately before the user's question in the final prompt.
prompt_intro = " Here is the prompt: "
import gradio as gr
def texttosql(question: str, conversation_history: list[dict[str, str]]) -> tuple:
    """Answer *question* through the SQL query engine, using earlier turns as context.

    Args:
        question: The user's natural-language question.
        conversation_history: Earlier turns of this session, each a dict with
            the keys "user" (the question) and "chatbot" (the answer given).
            The new turn is appended to this list in place.

    Returns:
        A 4-tuple ``(answer, sql_query, result, conversation_history)``.
        ``sql_query`` and ``result`` are read from the response metadata;
        ``sql_query`` falls back to an empty string and ``result`` to ``None``
        when the metadata does not carry them.
    """
    if conversation_history is None:
        conversation_history = []

    prompt_parts = [preamble]
    # Only mention earlier turns when there are any; on the first question the
    # sentence would otherwise introduce an empty history.
    if conversation_history:
        # f-string formatting tolerates a previous answer that was None,
        # which plain string concatenation would reject.
        context = " ".join(
            f'{turn["user"]} {turn["chatbot"]}' for turn in conversation_history
        )
        prompt_parts.append(
            "the user previously asked and received the following: " + context
        )
    # The parts are joined with a space so the preamble no longer runs
    # straight into the next sentence; prompt_intro carries its own leading
    # space.
    # NOTE(review): the whole history is resent on every call, so the prompt
    # grows with each turn -- consider capping the number of turns.
    prompt = " ".join(prompt_parts) + prompt_intro + question

    response = query_engine.query(prompt)
    conversation_history.append({"user": question, "chatbot": response.response})

    metadata = response.metadata or {}
    return (
        response.response,
        metadata.get("sql_query", ""),
        metadata.get("result"),
        conversation_history,
    )
# Input components: the question box plus per-session state holding the
# conversation history (starts as an empty list).
inputs = [
    gr.Textbox(lines=10, label="Question"),
    gr.State(value=[]),
]
# Output components, in the order of the values returned by texttosql:
# answer text, generated SQL, query result, updated history state.
outputs = [
    gr.Textbox(label="Chatbot Response", type="text"),
    gr.Textbox(label="sql_query", autoscroll=False, type="text"),
    gr.Textbox(label="Metadata_result", autoscroll=False, type="text"),
    # gr.Textbox(label="Source 3", max_lines = 10, autoscroll = False, type="text"),
    gr.State(),
]
# Build the interface around texttosql and start serving it.
demo = gr.Interface(
    fn=texttosql,
    inputs=inputs,
    outputs=outputs,
    title="txttosql Chatbot",
    description="Enter a question and see the processed outputs in collapsible boxes.",
)
demo.launch()