# prefilter_app / app.py
# Last commit 8560376 by mtyrrell: "Fixed local testing config"
import torch

# Start-up diagnostics: report whether a CUDA device is usable. Any failure here
# is reported and ignored, so the app simply carries on (on CPU).
try:
    cuda_ok = torch.cuda.is_available()
    print(f"Is CUDA available: {cuda_ok}")
    if not cuda_ok:
        print("No CUDA device available - using CPU")
    else:
        try:
            device_idx = torch.cuda.current_device()
            print(f"CUDA device: {torch.cuda.get_device_name(device_idx)}")
        except Exception as exc:
            print(f"Error getting CUDA device name: {exc}")
except Exception as exc:
    print(f"Error checking CUDA availability: {exc}")
    print("Continuing with CPU...")
import logging
import os
from datetime import datetime
from io import BytesIO

import streamlit as st
from huggingface_hub import login
from openai import OpenAI

from src.auth import validate_login
from src.utils import create_excel, setup_logging, getconfig
from src.pipeline import process_data

# Configure logging (project helper) before creating this module's logger.
setup_logging()
logger = logging.getLogger(__name__)

# Local development only: uncomment to load environment variables from a .env file.
# from dotenv import load_dotenv
# load_dotenv()

# Application configuration parsed from config.cfg (relative to the working directory).
config = getconfig("config.cfg")
@st.cache_resource
def get_azure_openai_client():
    """Create the OpenAI client used to call the Azure OpenAI deployment.

    Decorated with ``st.cache_resource``: the client is created once and shared by
    all sessions and reruns served by this Streamlit process (it is not per-session).

    Reads AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_API_VERSION and AZURE_OPENAI_API_KEY
    from the environment.

    Returns:
        openai.OpenAI: client with ``base_url`` = AZURE_OPENAI_ENDPOINT.

    Raises:
        ValueError: if any required variable is unset or empty; the message names
            the missing variables (main() shows it to the user).
        Exception: anything raised by the client constructor is logged and re-raised.
    """
    required_vars = (
        "AZURE_OPENAI_ENDPOINT",
        "AZURE_OPENAI_API_VERSION",
        "AZURE_OPENAI_API_KEY",
    )
    settings = {name: os.environ.get(name) for name in required_vars}
    try:
        missing = [name for name, value in settings.items() if not value]
        if missing:
            raise ValueError(
                "Missing required Azure OpenAI environment variables: "
                f"{', '.join(missing)}. Please check your .env file."
            )
        # NOTE(review): AZURE_OPENAI_API_VERSION is required above but is not passed
        # to the client; the plain OpenAI client sends no api-version parameter.
        # Presumably AZURE_OPENAI_ENDPOINT is Azure's v1 endpoint (.../openai/v1/),
        # which does not need one — confirm. A classic resource endpoint would need
        # openai.AzureOpenAI(azure_endpoint=..., api_version=..., api_key=...).
        client = OpenAI(
            api_key=settings["AZURE_OPENAI_API_KEY"],
            base_url=settings["AZURE_OPENAI_ENDPOINT"],
        )
        logger.info("Azure OpenAI client initialized successfully")
        return client
    except Exception as e:
        logger.error(f"Failed to initialize Azure OpenAI client: {str(e)}")
        raise
def get_azure_deployment(config_path="config.cfg", default="gpt-4o-mini"):
    """Get the Azure OpenAI deployment name from the config file.

    Reads option ``DEPLOYMENT`` in section ``[deployments]`` of *config_path*,
    loaded via ``getconfig`` (presumably a ``configparser.ConfigParser`` — confirm
    in src.utils). The file is re-read on every call.

    Args:
        config_path: config file to read; defaults to the app's config.cfg.
        default: deployment name returned when the option cannot be read or is blank.

    Returns:
        str: the configured deployment name, or *default* on any failure.
    """
    try:
        deployment = getconfig(config_path).get("deployments", "DEPLOYMENT").strip()
        if not deployment:
            # A blank value is not a usable deployment name; fall back instead.
            raise ValueError("option 'DEPLOYMENT' in section 'deployments' is empty")
    except Exception as e:
        # Deliberately best-effort: any read/parse problem falls back to the default.
        logger.error(
            f"Failed to read deployment from config: {str(e)}. "
            f"Using default deployment '{default}'."
        )
        return default
    logger.info(f"Using Azure OpenAI deployment: {deployment}")
    return deployment
# Main app logic

# Sensitivity label shown in the sidebar -> threshold passed to process_data
# (ref. process_data for how the threshold is applied to review / reject).
_SENSITIVITY_OPTIONS = {
    "Low": 3,
    "Medium": 4,
    "High": 5,
}

_XLSX_MIME = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"


def _init_session_state():
    """Create every session-state key the app reads, keeping any existing values."""
    defaults = {
        'authenticated': False,
        'hf_logged_in': False,
        'data_processed': False,
        'df': None,
        'show_button': True,
        'processing': False,
        'processed_file_key': None,
        'error_message': None,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


def _reset_processing_state():
    """Return to the 'ready to analyse' state: button shown, nothing running, no results."""
    st.session_state['show_button'] = True
    st.session_state['processing'] = False
    st.session_state['data_processed'] = False
    st.session_state['df'] = None


def _fail(message):
    """Reset the processing state, queue *message* for display, and rerun the script.

    The message is kept in session state instead of being shown with st.error here,
    because an element written just before st.rerun() is cleared by the fresh run.
    main() displays it on the next run. This function never returns.
    """
    st.session_state['error_message'] = message
    _reset_processing_state()
    st.rerun()


def _ensure_hf_login():
    """Log in to the Hugging Face Hub with the HF_TOKEN env var, once per session.

    Shows an error and stops the script if HF_TOKEN is unset or empty.
    """
    if st.session_state['hf_logged_in']:
        return
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        st.error("The HF_TOKEN environment variable is not set; cannot log in to the Hugging Face Hub.")
        st.stop()
    login(token=hf_token, add_to_git_credential=True)
    st.session_state['hf_logged_in'] = True


def _render_login():
    """Render the username/password form; mark the session authenticated on success."""
    username = st.text_input("Username")
    password = st.text_input("Password", type="password")
    if st.button("Login"):
        if validate_login(username, password):
            st.session_state['authenticated'] = True
            st.rerun()
        else:
            st.error("Incorrect username or password")


def _render_sidebar():
    """Render the sidebar: instructions, template download and sensitivity selector.

    Returns:
        int: threshold from _SENSITIVITY_OPTIONS for the selected sensitivity level.
    """
    with st.sidebar:
        with st.expander("ℹ️ - Instructions", expanded=False):
            st.markdown(
                """
                1. **Download the Excel Template file (below)**
                2. **[OPTIONAL]: Select the desired filtering sensitivity level (below)**
                3. **Copy/paste the requisite application data in the template file. Best practice is to 'paste as values'**
                4. **Upload the template file in the area to the right (or click browse files)**
                5. **Click 'Start Analysis'**

                The tool will start processing the uploaded application data. This can take some time
                depending on the number of applications and the length of text in each. For example, a file with 1000 applications
                could be expected to take approximately 5 minutes.

                ***NOTE** - you can also simply rename the column headers in your own file. The headers must match the column names in the template for the tool to run properly.*
                """
            )
        # Excel file download
        st.download_button(
            label="Download Excel Template",
            data=create_excel(),
            file_name="upload_template.xlsx",
            mime=_XLSX_MIME,
        )
        sens_input = st.sidebar.radio(
            label='Select the Sensitivity Level [OPTIONAL]',
            help=(
                'Decreasing the level of sensitivity results in less '
                'applications filtered out. This also '
                'reduces the probability of false negatives (FNs). The rate of '
                'FNs at the lowest setting is approximately 6 percent, and '
                'approaches 13 percent at the highest setting. '
                'NOTE: changing this setting does not affect the raw data in the '
                'Excel output file (only the labels)'
            ),
            options=list(_SENSITIVITY_OPTIONS),
            index=list(_SENSITIVITY_OPTIONS).index("High"),
            horizontal=False,
        )
    sens_level = _SENSITIVITY_OPTIONS[sens_input]
    logger.info(f"Sensitivity level applied: {sens_level}")
    return sens_level


def _render_about():
    """Render the 'About this app' expander with the pipeline diagram."""
    # NOTE(review): the original indentation was lost; confirm whether this expander
    # was meant to be in the sidebar rather than the main area.
    with st.expander("ℹ️ - About this app", expanded=False):
        st.write(
            """
            This tool provides an interface for running an automated preliminary assessment of applications for a call for applications.
            The tool functions by running selected text fields from the application through a series of LLMs fine-tuned for text classification (ref. diagram below).
            The resulting output classifications are used to compute a score and a suggested pre-filtering action. The tool has been tested against
            human assessors and exhibits an extremely low false negative rate (<6%) at a Sensitivity Level of 'Low' (i.e. rejection threshold for predicted score < 4).
            """
        )
        st.image('assets/pipeline.png')


def _run_analysis(uploaded_file, sens_level):
    """Process *uploaded_file* once, then offer the results as an Excel download.

    The processed DataFrame is kept in st.session_state['df'], so later reruns
    (e.g. widget clicks) do not reprocess the file. Failures go through _fail(),
    which queues an error message and reruns the script.
    """
    logger.info(f"File uploaded: {uploaded_file.name}")
    if not st.session_state['data_processed']:
        logger.info("Starting data processing")
        try:
            # Initialize Azure OpenAI client and get deployment name
            azure_client = get_azure_openai_client()
            azure_deployment = get_azure_deployment()
            df = process_data(
                uploaded_file,
                sens_level,
                azure_client,
                azure_deployment,
            )
        except ValueError as e:
            # Validation errors carry a message meant for the user.
            logger.error(f"Validation error: {str(e)}")
            _fail(str(e))
        except Exception as e:
            logger.exception(f"Error in process_data: {str(e)}")
            _fail("An unexpected error occurred. Please check your input file and try again.")
        else:
            st.session_state['df'] = df
            st.session_state['data_processed'] = True
            logger.info("Data processing completed successfully")

    excel_buffer = BytesIO()
    try:
        st.session_state['df'].to_excel(excel_buffer, index=False, engine='openpyxl')
        excel_buffer.seek(0)
    except Exception as e:
        logger.exception(f"Error processing file: {str(e)}")
        _fail("Failed to process the file. Please ensure your column names match the template file.")

    current_datetime = datetime.now().strftime('%d-%m-%Y_%H-%M-%S')
    st.download_button(
        label="Download Analysis Data File",
        data=excel_buffer,
        file_name=f'processed_applications_{current_datetime}.xlsx',
        mime=_XLSX_MIME,
        # Downloading returns the UI to the "ready for a new analysis" state.
        on_click=_reset_processing_state,
    )


def main():
    """Streamlit entry point: login form until authenticated, then the pre-filtering tool."""
    _init_session_state()
    if not st.session_state['authenticated']:
        _render_login()
        return

    _ensure_hf_login()

    st.title('Application Pre-Filtering Tool')
    sens_level = _render_sidebar()
    _render_about()

    uploaded_file = st.file_uploader(
        "Select a file containing application pre-filtering data (see instructions in the sidebar)"
    )

    # Show an error queued by _fail() during the previous run, then clear it.
    if st.session_state['error_message']:
        st.error(st.session_state['error_message'])
        st.session_state['error_message'] = None

    # Identify the upload so results are only ever shown for the file that was
    # analysed; if it is removed or replaced, drop the stale run/results.
    file_key = None if uploaded_file is None else (uploaded_file.name, uploaded_file.size)
    if st.session_state['processing'] and file_key != st.session_state['processed_file_key']:
        _reset_processing_state()

    # Offer the button only when a file is present and no analysis is in progress.
    if uploaded_file is not None and st.session_state['show_button'] and not st.session_state['processing']:
        if st.button("Start Analysis", key="start_analysis"):
            st.session_state['show_button'] = False
            st.session_state['processing'] = True
            st.session_state['processed_file_key'] = file_key
            st.rerun()

    if st.session_state['processing']:
        _run_analysis(uploaded_file, sens_level)
# Streamlit executes the script as __main__, so the UI still starts under
# `streamlit run app.py`, while a plain import (e.g. from tests) does not run it.
if __name__ == "__main__":
    main()