# llm_classifier / app/main.py
# Last commit: 984fcd8 "wrap up" (argmin)
import streamlit as st
import pandas as pd
from utils.prompt import generate_prompts
from utils.classification import apply_classification
from utils.validation import generate_classification_model
from utils.api import get_openai_client
from utils.tokens import estimate_token_count
from utils.vis import display_metrics_as_table, display_model_config
from config.model_params import DEFAULT_PARAMS
def _strip_front_matter(text, delimiter="---"):
    """Return *text* without its leading front-matter section.

    The front matter is everything from an opening *delimiter* at the very
    start of *text* up to the next *delimiter*. Only that first section is
    removed. Later occurrences of the delimiter (markdown horizontal rules,
    table separators such as ``|---|``) are kept intact. Text without a
    complete front-matter section is returned unchanged.
    """
    if not text.startswith(delimiter):
        return text
    # maxsplit=2 yields ["", front_matter, body]; the body keeps its own '---'.
    parts = text.split(delimiter, 2)
    return parts[2] if len(parts) == 3 else text


st.set_page_config(layout="wide")

# Define the tabs
tab1, tab2 = st.tabs(["📖 Documentation", "🤖 Classifier"])

# Tab 1: Readme
with tab1:
    # A missing README should only blank this tab, not abort the whole script
    # (which would also prevent the Classifier tab from rendering).
    try:
        # utf-8-sig drops a BOM that would otherwise hide the leading '---'.
        with open("README.md", encoding="utf-8-sig") as readme_file:
            readme_text = readme_file.read()
    except FileNotFoundError:
        st.warning("README.md not found; documentation is unavailable.")
    else:
        readme_content = _strip_front_matter(readme_text)
        st.markdown(readme_content)
# Tab 2: Classifier
with tab2:
    # Streamlit App Title
    st.title("🤖 LLM-based Classifier")

    # Upload Dataset
    uploaded_file = st.sidebar.file_uploader("Upload a CSV file", type=["csv"])
    if uploaded_file:
        # The upload is arbitrary user input: report unreadable files in the UI
        # instead of letting the exception abort the app.
        try:
            df = pd.read_csv(uploaded_file)
        except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
            st.error(f"Could not read the uploaded CSV file: {err}")
            st.stop()
        st.write("### Data Preview", df.head())

        # Select Target Column
        label_column = st.selectbox(
            "Select target column (if available):",
            ["None"] + df.columns.tolist(),
            index=0,
        )
        if label_column == "None":
            st.warning("No target column selected. The app will run in inference mode.")
            label_column = None
            filtered_columns = df.columns.tolist()
        else:
            # Work with string labels throughout so numeric class ids (e.g. 0/1)
            # compare equal to the label names typed into the inputs below.
            df[label_column] = df[label_column].astype(str)
            # The target is not offered as a feature, so it can never be sent
            # to the model as an input.
            filtered_columns = [col for col in df.columns if col != label_column]

        # Feature Selection: every candidate column is pre-selected.
        features = st.multiselect(
            "Select features:",
            filtered_columns,
            default=filtered_columns,
        )
        if not features:
            st.error("Please select at least one feature to proceed.")
            st.stop()

        # Specify Prediction Column Name
        prediction_column = st.text_input(
            "Enter the name of the column to store predictions:", "Predicted Label"
        )
        # Predictions are stored under this name. Reusing an existing column
        # would overwrite it; overwriting the target would make the evaluation
        # compare the predictions with themselves.
        if not prediction_column.strip():
            st.error("Please enter a name for the prediction column.")
            st.stop()
        if prediction_column in df.columns:
            st.error(
                f"Column '{prediction_column}' already exists in the dataset. "
                "Please choose a different name for the prediction column."
            )
            st.stop()

        # Define Labels and Descriptions
        if label_column is not None:
            # Values are strings here because of the astype(str) above.
            unique_labels = df[label_column].unique().tolist()
            n_unique_target_values = len(unique_labels)
            # Ask for confirmation *before* rendering one pair of inputs per
            # class, so a free-text column picked by mistake does not flood the page.
            if n_unique_target_values > 20:
                st.warning(
                    f"The selected column '{label_column}' has {n_unique_target_values} unique values, "
                    f"which may not be ideal as a target for classification."
                )
                proceed = st.checkbox(
                    f"I understand and still want to use '{label_column}' as the target column."
                )
                if not proceed:
                    st.stop()
            st.write(f"Automatically detected {n_unique_target_values} unique values in the target column.")
            # Pre-fill each label name with a value found in the target column.
            default_label_names = unique_labels
        else:
            # Manual entry: there is no target column to read the classes from.
            num_labels = st.number_input("Number of unique labels:", min_value=2, step=1)
            default_label_names = [""] * int(num_labels)

        # Create columns for labels and descriptions
        col1, col2 = st.columns(2)
        label_descriptions = {}
        for i, default_name in enumerate(default_label_names):
            with col1:
                label = st.text_input(
                    f"Label {i+1} name:",
                    value=default_name,
                    key=f"label_name_{i}",
                )
            with col2:
                description = st.text_input(
                    f"Label {i+1} description:",
                    value="",
                    key=f"label_desc_{i}",
                )
            # Unfilled names would all collapse onto one "" class; skip them.
            if not label.strip():
                continue
            if label in label_descriptions:
                st.warning(f"Label '{label}' is defined more than once; only its last description is used.")
            label_descriptions[label] = description

        # Compare user-provided labels with unique target values
        if label_column is not None:
            unique_target_values = set(unique_labels)
            user_provided_labels = set(label_descriptions)
            # sorted() keeps the warning text stable between reruns.
            missing_labels = sorted(unique_target_values - user_provided_labels)
            extra_labels = sorted(user_provided_labels - unique_target_values)
            if missing_labels:
                st.warning(
                    f"The following values in the target column are not accounted for in the labels: {', '.join(map(str, missing_labels))}."
                )
            if extra_labels:
                st.warning(
                    f"The following user-provided labels do not match any values in the target column: {', '.join(map(str, extra_labels))}."
                )

        # Few-Shot Prompting needs ground-truth labels, so it is only available
        # with a target column. A disabled checkbox can still carry a stale True
        # in session state, hence the explicit second condition.
        use_few_shot = st.checkbox(
            "Use few-shot prompting with examples from the target column",
            value=False,
            disabled=label_column is None,
        )
        use_few_shot = use_few_shot and label_column is not None

        example_rows = []
        if use_few_shot:
            st.info("Few-shot prompting is enabled. Examples will be selected from the dataset.")
            # Up to 2 random examples per class, with a fixed seed for
            # reproducibility. Shuffle-then-head keeps every column, including
            # the target. groupby().apply() on the grouping column is deprecated
            # (pandas 2.2) and will drop that column in future versions.
            few_shot_examples = (
                df.sample(frac=1, random_state=42)
                .groupby(label_column)
                .head(2)
                .sort_values(label_column, kind="stable")
            )
            # Show the few-shot examples for reference
            st.write("### Few-Shot Examples")
            st.write(few_shot_examples[[*features, label_column]])
            # Prepare Few-Shot Examples for Prompting
            example_rows = [
                {
                    "features": {feature: example[feature] for feature in features},
                    "label": example[label_column],
                }
                for _, example in few_shot_examples.iterrows()
            ]
            # A row shown to the model as an example is never classified.
            remaining_data = df.drop(few_shot_examples.index)
        else:
            remaining_data = df

        # number_input below requires max_value >= min_value (1), so an empty
        # remainder must be handled before it is built.
        if remaining_data.empty:
            st.error(
                "There are no rows left to classify. Upload a larger dataset "
                "or disable few-shot prompting."
            )
            st.stop()

        # Limit rows based on user input to control costs
        num_rows_to_send = st.number_input(
            'Select number of rows to send to OpenAI ($$)',
            min_value=1,
            max_value=len(remaining_data),
            value=min(20, len(remaining_data)),
        )
        if len(remaining_data) > num_rows_to_send:
            st.warning(f"Only the first {num_rows_to_send} rows of the remaining dataset will be sent to OpenAI to minimize costs.")
        # .copy() so the prediction column is added to an independent frame,
        # not to a slice of `df` (avoids SettingWithCopyWarning).
        limited_data = remaining_data.head(int(num_rows_to_send)).copy()

        # API Key and Model Parameters
        openai_api_key = st.sidebar.text_input("Enter your OpenAI API Key:", type="password")
        available_models = DEFAULT_PARAMS["available_models"]
        model_params = {
            "model": st.selectbox(
                "Model:",
                available_models,
                index=available_models.index(DEFAULT_PARAMS["model"]),
            ),
            # float(): st.slider rejects mixing float bounds with an int value.
            "temperature": st.slider(
                "Temperature:",
                min_value=0.0,
                max_value=1.0,
                value=float(DEFAULT_PARAMS["temperature"]),
            ),
            "max_tokens": DEFAULT_PARAMS["max_tokens"],
        }
        # Show the configuration that will actually be used: the same keys as
        # DEFAULT_PARAMS, with the user's selections overriding the defaults.
        display_model_config({**DEFAULT_PARAMS, **model_params})
        verbose = st.checkbox("Verbose", value=False)

        # Classification Button
        if st.button("Run Classification"):
            if not openai_api_key:
                st.error("Please provide a valid OpenAI API Key.")
            elif not label_descriptions:
                st.error("Please provide at least one label name.")
            else:
                # Initialize OpenAI client
                client = get_openai_client(api_key=openai_api_key)
                # Dynamically create the Pydantic model for validation
                ClassificationOutput = generate_classification_model(list(label_descriptions.keys()))
                # Create a placeholder for the progress bar
                progress_bar = st.progress(0)
                progress_text = st.empty()

                def classify_row(row, position, total_rows):
                    """Classify one row and advance the progress indicator.

                    Args:
                        row: One row of `limited_data` (a pandas Series).
                        position: 0-based position of the row within
                            `limited_data`. This is not the DataFrame index
                            label, which may be non-contiguous.
                        total_rows: Number of rows being classified.

                    Returns:
                        Whatever `apply_classification` returns for the row.
                    """
                    progress_bar.progress((position + 1) / total_rows)
                    progress_text.text(f"Processing row {position + 1}/{total_rows}...")
                    # Generate system and user prompts
                    system_prompt, user_prompt = generate_prompts(
                        row=row.to_dict(),
                        label_descriptions=label_descriptions,
                        features=features,
                        example_rows=example_rows,
                    )
                    # Show the prompts in an expander for transparency
                    if verbose:
                        with st.expander(f"OpenAI Call Input for Row Index {row.name}"):
                            st.write("**System Prompt:**")
                            st.code(system_prompt)
                            st.write(f"Token Count (System Prompt): {estimate_token_count(system_prompt, model_params['model'])}")
                            st.write("**User Prompt:**")
                            st.code(user_prompt)
                            st.write(f"Token Count (User Prompt): {estimate_token_count(user_prompt, model_params['model'])}")
                    # Make the OpenAI call and validate the output
                    return apply_classification(
                        client=client,
                        model_params=model_params,
                        ClassificationOutput=ClassificationOutput,
                        system_prompt=system_prompt,
                        user_prompt=user_prompt,
                        verbose=verbose,
                        st=st,
                    )

                # Classify each row of the limited data. enumerate() supplies the
                # position: after few-shot rows are dropped the index labels have
                # gaps, and using them pushed the progress fraction past 1.0.
                total_rows = len(limited_data)
                predictions = [
                    classify_row(row, position, total_rows)
                    for position, (_, row) in enumerate(limited_data.iterrows())
                ]

                # Add predictions to the DataFrame
                limited_data[prediction_column] = predictions
                # Reset progress bar and text
                progress_bar.empty()
                progress_text.empty()
                # Display Predictions
                st.write(f"### Predictions ({prediction_column})", limited_data)

                # Evaluate if ground truth is available
                if label_column is not None:
                    from utils.evaluation import evaluate_predictions
                    report = evaluate_predictions(limited_data[label_column], limited_data[prediction_column])
                    st.write("### Evaluation Metrics")
                    display_metrics_as_table(report)
                else:
                    st.warning("Inference mode: No target column provided, so no evaluation metrics are available.")

                # Count predictions
                label_counts = limited_data[prediction_column].value_counts().reset_index()
                label_counts.columns = ["Label", "Count"]
                st.subheader("Prediction Statistics")
                st.table(label_counts)
    else:
        st.write('Drag and drop a CSV to get started.')
        st.markdown(
            "Some ideas here:\n"
            "- (Binary) https://www.kaggle.com/datasets/ozlerhakan/spam-or-not-spam-dataset\n"
            "- (Multi-class) https://www.kaggle.com/datasets/mdismielhossenabir/sentiment-analysis\n"
            "- (Multi-class) https://www.kaggle.com/datasets/pashupatigupta/emotion-detection-from-text\n"
        )