# Spaces: Sleeping
from collections import Counter

import pandas as pd
import streamlit as st

from config.model_params import DEFAULT_PARAMS
from utils.api import get_openai_client
from utils.classification import apply_classification
from utils.prompt import generate_prompts
from utils.tokens import estimate_token_count
from utils.validation import generate_classification_model
from utils.vis import display_metrics_as_table, display_model_config
st.set_page_config(layout="wide")

# Define the tabs: project documentation and the interactive classifier.
tab1, tab2 = st.tabs(["📖 Documentation", "🤖 Classifier"])

# Tab 1: Readme
with tab1:
    # README.md is expected to start with a header fenced by two `---` lines
    # (presumably Space metadata); only the text after the second `---` is
    # rendered. maxsplit=2 keeps any later `---` (e.g. Markdown horizontal
    # rules) in the rendered text, whereas splitting on every `---` and
    # re-joining with "" deleted them.
    try:
        with open("README.md", encoding="utf-8") as readme_file:
            readme_text = readme_file.read()
    except FileNotFoundError:
        st.warning("README.md not found, so no documentation can be shown.")
    else:
        header_and_body = readme_text.split("---", 2)
        # Without a fenced header, fall back to rendering the whole file.
        readme_content = header_and_body[2] if len(header_and_body) == 3 else readme_text
        st.markdown(readme_content)
# Tab 2: Classifier (helper functions first, then the tab layout itself)


def _collect_label_entries(df, label_column):
    """Render the label name/description inputs and return what the user entered.

    With a target column, one pair of inputs is pre-filled for every unique
    value of that column. In inference mode the user chooses how many labels
    to type in by hand.

    Args:
        df: The uploaded dataset.
        label_column: Name of the target column, or ``None`` in inference mode.

    Returns:
        A list of ``(name, description)`` tuples in on-screen order. Names are
        returned exactly as typed, so blanks and duplicates are possible;
        ``_label_problems`` reports them.
    """
    if label_column:
        unique_labels = df[label_column].unique()
        st.write(f"Automatically detected {len(unique_labels)} unique values in the target column.")
        default_names = [str(value) for value in unique_labels]
    else:
        num_labels = st.number_input("Number of unique labels:", min_value=2, step=1)
        default_names = [""] * int(num_labels)

    col1, col2 = st.columns(2)
    entries = []
    for i, default_name in enumerate(default_names):
        with col1:
            name = st.text_input(f"Label {i+1} name:", value=default_name, key=f"label_name_{i}")
        with col2:
            description = st.text_input(
                f"Label {i+1} description:", value="", key=f"label_desc_{i}"
            )
        entries.append((name, description))
    return entries


def _label_problems(label_entries):
    """Return error messages for blank or duplicated label names.

    Names are compared after stripping surrounding whitespace. Without this
    check, a blank name would be passed on as a class. Building a dict from
    duplicated names would also silently keep only the last description.

    Args:
        label_entries: ``(name, description)`` tuples from ``_collect_label_entries``.

    Returns:
        A list of messages. It is empty when every name is non-blank and unique.
    """
    names = [name.strip() for name, _ in label_entries]
    problems = []
    if not all(names):
        problems.append("Every label needs a non-empty name.")
    duplicated = sorted(name for name, count in Counter(names).items() if name and count > 1)
    if duplicated:
        problems.append(f"Label names must be unique. Duplicated: {', '.join(duplicated)}.")
    return problems


def _check_target_labels(df, label_column, user_labels):
    """Warn when the target column looks unsuitable or disagrees with the labels.

    Stops the script run (``st.stop``) when the column has more than 20 unique
    values and the user has not ticked the confirmation checkbox.

    Args:
        df: The dataset. ``df[label_column]`` has already been cast to ``str``.
        label_column: Name of the target column.
        user_labels: Set of label names entered by the user.
    """
    unique_target_values = set(df[label_column].unique())
    n_unique_target_values = len(unique_target_values)
    if n_unique_target_values > 20:
        st.warning(
            f"The selected column '{label_column}' has {n_unique_target_values} unique values, "
            "which may not be ideal as a target for classification."
        )
        proceed = st.checkbox(
            f"I understand and still want to use '{label_column}' as the target column."
        )
        if not proceed:
            st.stop()

    missing_labels = unique_target_values - user_labels
    extra_labels = user_labels - unique_target_values
    if missing_labels:
        st.warning(
            f"The following values in the target column are not accounted for in the labels: {', '.join(map(str, missing_labels))}."
        )
    if extra_labels:
        st.warning(
            f"The following user-provided labels do not match any values in the target column: {', '.join(map(str, extra_labels))}."
        )


def _select_few_shot_examples(df, label_column, per_class=2, random_state=42):
    """Pick up to ``per_class`` random example rows for every target value.

    The frame is shuffled once, then the first ``per_class`` rows of each class
    are kept. This draws a random subset per class without using
    ``groupby(...).apply``. pandas deprecates ``apply`` operating on the
    grouping column, and the callers below need that column in the result.

    Args:
        df: The dataset.
        label_column: Name of the target column to group by.
        per_class: Maximum number of examples per target value.
        random_state: Seed for the shuffle, so the same rows are picked on
            every rerun.

    Returns:
        The selected rows, with their original index labels, ordered by
        ``label_column``.
    """
    shuffled = df.sample(frac=1, random_state=random_state)
    examples = shuffled.groupby(label_column).head(per_class)
    return examples.sort_values(label_column, kind="stable")


def _classify_rows(data, *, client, model_params, classification_output,
                   label_descriptions, features, example_rows, verbose):
    """Send every row of ``data`` to the model and return one result per row.

    A progress bar is shown while rows are processed and cleared afterwards.
    With ``verbose``, the prompts for each row and their estimated token counts
    are shown in an expander.

    Returns:
        A list holding whatever ``apply_classification`` returned for each row,
        in the row order of ``data``. This is presumably the predicted label;
        confirm in utils.classification.
    """
    progress_bar = st.progress(0)
    progress_text = st.empty()
    total_rows = len(data)
    predictions = []
    # Progress is based on the row's *position*, not its index label. Once the
    # few-shot rows are dropped the index has gaps, so label + 1 can exceed
    # total_rows and push the value above 1.0, which st.progress rejects.
    for position, (row_label, row) in enumerate(data.iterrows(), start=1):
        progress_bar.progress(position / total_rows)
        progress_text.text(f"Processing row {position}/{total_rows}...")

        system_prompt, user_prompt = generate_prompts(
            row=row.to_dict(),
            label_descriptions=label_descriptions,
            features=features,
            example_rows=example_rows,
        )
        if verbose:
            # Show the exact prompts for transparency.
            with st.expander(f"OpenAI Call Input for Row Index {row_label}"):
                st.write("**System Prompt:**")
                st.code(system_prompt)
                st.write(f"Token Count (System Prompt): {estimate_token_count(system_prompt, model_params['model'])}")
                st.write("**User Prompt:**")
                st.code(user_prompt)
                st.write(f"Token Count (User Prompt): {estimate_token_count(user_prompt, model_params['model'])}")

        predictions.append(
            apply_classification(
                client=client,
                model_params=model_params,
                ClassificationOutput=classification_output,
                system_prompt=system_prompt,
                user_prompt=user_prompt,
                verbose=verbose,
                st=st,
            )
        )

    progress_bar.empty()
    progress_text.empty()
    return predictions


def _show_results(data, label_column, prediction_column):
    """Display the predictions, evaluation metrics (when a target exists) and label counts."""
    st.write(f"### Predictions ({prediction_column})", data)

    if label_column is not None:
        # Imported lazily, only when there is ground truth to evaluate against.
        from utils.evaluation import evaluate_predictions

        report = evaluate_predictions(data[label_column], data[prediction_column])
        st.write("### Evaluation Metrics")
        display_metrics_as_table(report)
    else:
        st.warning("Inference mode: No target column provided, so no evaluation metrics are available.")

    label_counts = data[prediction_column].value_counts().reset_index()
    label_counts.columns = ["Label", "Count"]
    st.subheader("Prediction Statistics")
    st.table(label_counts)


with tab2:
    st.title("🤖 LLM-based Classifier")

    # Upload Dataset
    uploaded_file = st.sidebar.file_uploader("Upload a CSV file", type=["csv"])
    if uploaded_file:
        try:
            df = pd.read_csv(uploaded_file)
        except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
            st.error(f"Could not read the uploaded CSV: {err}")
            st.stop()
        st.write("### Data Preview", df.head())

        # Select Target Column ("None" means inference mode)
        label_column = st.selectbox(
            "Select target column (if available):",
            ["None"] + df.columns.tolist(),
            index=0,
        )
        if label_column == "None":
            st.warning("No target column selected. The app will run in inference mode.")
            label_column = None
            filtered_columns = df.columns.tolist()
        else:
            # Cast to str so target values can be compared with the label names
            # typed into the text inputs below.
            df[label_column] = df[label_column].astype(str)
            filtered_columns = [col for col in df.columns if col != label_column]

        # Feature Selection. The target column is never offered as an option,
        # so it cannot end up among the features.
        features = st.multiselect("Select features:", filtered_columns, default=filtered_columns)
        if not features:
            st.error("Please select at least one feature to proceed.")
            st.stop()

        # Specify Prediction Column Name
        prediction_column = st.text_input(
            "Enter the name of the column to store predictions:", "Predicted Label"
        )
        if not prediction_column.strip():
            st.error("Please enter a name for the prediction column.")
            st.stop()
        if prediction_column in df.columns:
            # Writing predictions into an existing column would overwrite data.
            # Overwriting the target would make the evaluation compare the
            # predictions with themselves.
            st.error(
                f"Column '{prediction_column}' already exists in the dataset. "
                "Please choose a different name."
            )
            st.stop()

        # Define Labels and Descriptions
        label_entries = _collect_label_entries(df, label_column)
        # Later duplicates win here; _label_problems blocks the run in that case.
        label_descriptions = dict(label_entries)

        # Compare user-provided labels with unique target values
        if label_column:
            _check_target_labels(df, label_column, set(label_descriptions))

        # Few-Shot Prompting: up to 2 examples per class are shown and excluded
        # from the rows sent for classification.
        use_few_shot = st.checkbox("Use few-shot prompting with examples from the target column", value=False)
        few_shot_examples = None
        remaining_data = df
        if use_few_shot and label_column:
            st.info("Few-shot prompting is enabled. Examples will be selected from the dataset.")
            few_shot_examples = _select_few_shot_examples(df, label_column)
            st.write("### Few-Shot Examples")
            st.write(few_shot_examples[[*features, label_column]])
            remaining_data = df.drop(few_shot_examples.index)
        elif use_few_shot:
            st.warning("Few-shot prompting needs a target column; continuing without examples.")

        # st.number_input below requires at least one row (min_value=1).
        if remaining_data.empty:
            st.error(
                "No rows are left to classify: the dataset is empty or every row "
                "was used as a few-shot example."
            )
            st.stop()

        # Limit rows based on user input to control costs
        num_rows_to_send = st.number_input(
            'Select number of rows to send to OpenAI ($$)',
            min_value=1,
            max_value=len(remaining_data),
            value=min(20, len(remaining_data)),
        )
        if len(remaining_data) > num_rows_to_send:
            st.warning(f"Only the first {num_rows_to_send} rows of the remaining dataset will be sent to OpenAI to minimize costs.")
        # Copy so the prediction column added later goes into an independent
        # frame rather than a slice of df.
        limited_data = remaining_data.head(num_rows_to_send).copy()

        # Prepare Few-Shot Examples for Prompting
        example_rows = []
        if few_shot_examples is not None:
            example_rows = [
                {
                    "features": {feature: example[feature] for feature in features},
                    "label": example[label_column],
                }
                for _, example in few_shot_examples.iterrows()
            ]

        # API Key and Model Parameters
        openai_api_key = st.sidebar.text_input("Enter your OpenAI API Key:", type="password")
        model_params = {
            "model": st.selectbox(
                "Model:",
                DEFAULT_PARAMS["available_models"],
                index=DEFAULT_PARAMS["available_models"].index(DEFAULT_PARAMS["model"]),
            ),
            "temperature": st.slider("Temperature:", min_value=0.0, max_value=1.0, value=DEFAULT_PARAMS["temperature"]),
            "max_tokens": DEFAULT_PARAMS["max_tokens"],
        }
        # NOTE(review): this displays DEFAULT_PARAMS, not the model/temperature
        # selected above. Confirm which one display_model_config should show.
        display_model_config(DEFAULT_PARAMS)
        verbose = st.checkbox("Verbose", value=False)

        # Classification Button
        if st.button("Run Classification"):
            label_issues = _label_problems(label_entries)
            if not openai_api_key:
                st.error("Please provide a valid OpenAI API Key.")
            elif label_issues:
                for issue in label_issues:
                    st.error(issue)
            else:
                client = get_openai_client(api_key=openai_api_key)
                # Dynamically create the Pydantic model for validation
                ClassificationOutput = generate_classification_model(list(label_descriptions.keys()))
                limited_data[prediction_column] = _classify_rows(
                    limited_data,
                    client=client,
                    model_params=model_params,
                    classification_output=ClassificationOutput,
                    label_descriptions=label_descriptions,
                    features=features,
                    example_rows=example_rows,
                    verbose=verbose,
                )
                _show_results(limited_data, label_column, prediction_column)
    else:
        st.write('Drag and drop a CSV to get started.')
        st.markdown("""
Some ideas here:
- (Binary) https://www.kaggle.com/datasets/ozlerhakan/spam-or-not-spam-dataset
- (Multi-class) https://www.kaggle.com/datasets/mdismielhossenabir/sentiment-analysis
- (Multi-class) https://www.kaggle.com/datasets/pashupatigupta/emotion-detection-from-text
""")