Spaces:
Sleeping
Sleeping
| import pickle | |
| import pandas as pd | |
| import streamlit as st | |
| from datasets import Dataset | |
| import model | |
| from utils import check_columns, count_labels | |
| # Main function to run the Streamlit app | |
def main():
    """Render the SetFit few-shot training page.

    Flow, top to bottom:

    1. Show the title and the session ID read from ``st.session_state.key``.
    2. Accept a CSV upload and validate it with ``check_columns``.
    3. Display the data and the per-label counts, then encode the
       ``label`` column as integer ids.
    4. When "Train" is pressed, call ``model.run_setfit_training`` and
       pickle the id -> original-label mapping to ``label.pkl`` inside the
       returned model path.

    Returns:
        None. All output goes to the Streamlit page and to ``label.pkl``.
    """
    st.title("Few Shot Learning Demo using SetFit")

    # The session ID is not created in this function; presumably another
    # page of the app stores it (TODO confirm). Accessing a missing
    # session_state attribute raises, so fail with a readable message.
    if "key" not in st.session_state:
        st.error("Session ID is not set. Please start the app from its main page.")
        return
    session_id = st.session_state.key
    st.write(f"Session ID: {session_id}")

    uploaded_file = st.file_uploader("Choose a CSV file to upload", type="csv")
    if uploaded_file is None:
        # Nothing uploaded yet; Streamlit reruns the script once a file arrives.
        return

    df = pd.read_csv(uploaded_file)
    if not check_columns(df):
        st.error("File must have 'text' and 'label' columns.")
        return

    st.write(df)
    label_counts = count_labels(df)
    st.write(f"Number of instances of each label class: {label_counts}")

    # dict.fromkeys keeps first-appearance order, so a given file always
    # produces the same label -> id assignment. Iterating a set of strings
    # can yield a different order on every interpreter start.
    distinct_labels = dict.fromkeys(df["label"].tolist())
    label_map = {label: idx for idx, label in enumerate(distinct_labels)}
    df["label"] = df["label"].map(label_map)
    dataset = Dataset.from_pandas(df)

    # NOTE(review): an empty model name is passed through unchanged; whether
    # run_setfit_training accepts that is not visible from here.
    model_name = st.text_input("Input the model name")
    pretrained_model_options = ["all-MiniLM-L6-v2", "paraphrase-MiniLM-L3-v2"]
    pretrained_model = st.selectbox(
        "Select a pretrained model", options=pretrained_model_options
    )

    if not st.button("Train"):
        return

    # NOTE(review): the original indentation was lost, so which statements sat
    # inside the spinner is an assumption; only the training call is placed here.
    with st.spinner("Training model..."):
        # The meaning of the trailing positional arguments (1, 10) is defined
        # by model.run_setfit_training, which is not visible here.
        model_path = model.run_setfit_training(
            session_id,
            pretrained_model,
            model_name,
            dataset,
            1,
            10,
        )
    st.write(f"Model checkpoint saved {model_path.split('/')[-1]}")

    # Store the inverse mapping (id -> original label) next to the model so
    # integer predictions can be translated back to the uploaded labels.
    id_to_label = {idx: label for label, idx in label_map.items()}
    with open(f"{model_path}/label.pkl", "wb") as f:
        pickle.dump(id_to_label, f)

    st.write("Training Finished")
    st.write("Go to Validation Page")