Spaces:
Runtime error
Runtime error
| # Set the page config | |
| import streamlit as st | |
# Page-level settings for the "Model Training" page. This call is issued
# before any other Streamlit command in the script.
page_settings = {
    "page_title": "Model_Training",
    "page_icon": ":open_file_folder:",
    "layout": "wide",
    "initial_sidebar_state": "collapsed",
}
st.set_page_config(**page_settings)
| # Importing necessary libraries | |
| import utils | |
| import streamlit as st | |
| import Functions.model_training_functions as model_training_functions | |
# Page heading
st.title("Model Training")

# Disabled: clearing the session state on the first load of the page.
# utils.clear_session_state_on_first_load("model_training_clear")

# Session-state entries that act as "key holders" for the widgets on this
# page. Each entry stores a one-item dict ({name: bool}); the widgets receive
# that dict as their ``key`` argument, and the Reset handler flips the bool.
session_state_keys = [
    "file_uploader_split_key_training",
    "file_uploader_train_key_training",
    "file_uploader_val_key_training",
    "file_uploader_test_key_training",
    "number_input_train_key",
    "number_input_val_key",
    "number_input_test_key",
    "split_method_key",
    "training_type_key",
    "class_labels_input_key_training",
]

# Create any key holder that does not exist yet; existing ones are left alone
for widget_key in session_state_keys:
    if widget_key in st.session_state:
        continue
    st.session_state[widget_key] = {widget_key: True}

# Boolean flags used by the validation / training flow; all start as False
flag_names = (
    "validation_triggered",
    "uploaded_files_cache_processing",
    "is_valid",
)
for flag_name in flag_names:
    if flag_name not in st.session_state:
        st.session_state[flag_name] = False
# Container reserved here, ahead of the other widgets, and filled with the
# file uploaders further down the script
file_uploader_container = st.container()

# User-facing training type -> label type passed to label validation
label_type_mapping = {
    "Object Detection": "Bboxes",
    "Instance Segmentation": "Masks",
}

# Two side-by-side columns: training type on the left, split method on the right
column_select_training, column_split_method = st.columns(2)

# Training type dropdown; any change resets the validation trigger
training_options = list(label_type_mapping)
selected_training = column_select_training.selectbox(
    "Select the training type:",
    training_options,
    index=0,
    on_change=utils.reset_validation_trigger,
    key=st.session_state["training_type_key"],
)

# Label type that corresponds to the chosen training type
label_type = label_type_mapping[selected_training]

# Split method selector; any change resets the validation trigger
split_options = ["Percentage Split", "Direct Upload"]
split_method = column_split_method.radio(
    "Select the dataset split method:",
    split_options,
    horizontal=True,
    on_change=utils.reset_validation_trigger,
    key=st.session_state["split_method_key"],
)

# Free-text class labels, pre-filled with the sample labels from utils
class_labels_input = st.text_area(
    "Enter class labels, separated by commas:",
    value=utils.sample_class_labels,
    on_change=utils.reset_validation_trigger,
    key=st.session_state["class_labels_input_key_training"],
)
# Remove unnecessary whitespace from the start and end of the raw input
class_labels_input = class_labels_input.strip()

# Split on commas, trim each entry and drop blanks. These string operations
# cannot fail for text input, so no exception handling is needed; the three
# names below are therefore always defined for the code further down.
class_labels = [
    label
    for label in (part.strip() for part in class_labels_input.split(","))
    if label
]

# Class ID -> label, numbered in the order the labels were entered
class_dict = dict(enumerate(class_labels))

# Label -> class ID (inverse of class_dict)
class_names_to_ids = {label: class_id for class_id, label in class_dict.items()}

# Nothing usable was entered (empty text or only commas/whitespace)
if not class_labels:
    st.warning(
        "Invalid format for class labels. Please enter labels separated by commas.",
        icon="⚠️",
    )
# Usage hints shown to the user, rendered as raw HTML
usage_note = """
<div style='text-align: justify;'>
<b>Note to Users:</b>
<ul>
<li>When moving to another page or if you wish to upload a new set of images and labels, don't forget to hit the <b>Reset</b> button. This helps in faster computation and frees up unused memory, ensuring smoother operation.</li>
<li>Select the training type, class labels, dataset split method and its parameters before uploading large data for faster computation and more efficient processing.</li>
</ul>
</div>
"""
st.markdown(usage_note, unsafe_allow_html=True)
# Two side-by-side columns: "Validate Input" on the left, "Reset" on the right
validate_button_col, reset_button_col = st.columns(2)

with reset_button_col:
    # Reset: wipe generated folders, widget state, session state and caches
    if st.button("Reset", use_container_width=True):
        # Clear the output folder and the dataset folders
        model_training_functions.delete_and_recreate_folder(
            model_training_functions.get_path("output")
        )
        model_training_functions.clear_data_folders()

        # Flip the boolean inside every key holder. The holders are passed as
        # widget keys, so a changed holder gives each widget a new key on the
        # next run. The list defined at the top of the script is reused so
        # the two places cannot drift apart.
        for session_state_key in session_state_keys:
            current_value = st.session_state[session_state_key][session_state_key]
            st.session_state[session_state_key] = {
                session_state_key: not current_value
            }

        # Drop every other session-state entry (flags, uploaded files, ...)
        protected_keys = frozenset(session_state_keys)
        for key in list(st.session_state.keys()):
            if key not in protected_keys:
                del st.session_state[key]

        # Clear the Streamlit caches
        st.cache_resource.clear()
        st.cache_data.clear()

        # Rerun the app to reflect the reset state. The script's globals are
        # deliberately not deleted by hand: st.rerun() stops this run right
        # here and the next run re-executes the script from the top.
        st.rerun()
# Dataset upload, validation and training for the selected split method.
# NOTE(review): block nesting was reconstructed from a copy of this script
# whose indentation was lost; confirm which widgets belong inside
# ``file_uploader_container``.


def _uploaded_files(state_key):
    """Return the files stored in session state under ``state_key``.

    Falls back to an empty list when the key is missing or holds ``None``,
    so the length checks below cannot raise before anything was uploaded.
    """
    return st.session_state.get(state_key) or []


# Set by whichever branch runs below; stays False if neither does
pct_condition_check = False

# Code for "Percentage Split" method
if split_method == "Percentage Split":
    with file_uploader_container:
        # User uploads images and labels in a single batch
        utils.display_file_uploader(
            "uploaded_files",
            "Choose images and labels...",
            st.session_state["file_uploader_split_key_training"],
            st.session_state["uploaded_files_cache_processing"],
        )

    # Three columns for the split percentages
    col1, col2, col3 = st.columns(3)

    # Positional arguments after the label: min 0, max 100, default, step 1.
    # Each input uses the key holder that matches its own split (the test and
    # validation holders used to be swapped).
    train_pct = col1.number_input(
        "Train Set Percentage",
        0,
        100,
        70,
        1,
        on_change=utils.reset_validation_trigger,
        key=st.session_state["number_input_train_key"],
    )
    test_pct = col2.number_input(
        "Test Set Percentage",
        0,
        100,
        15,
        1,
        on_change=utils.reset_validation_trigger,
        key=st.session_state["number_input_test_key"],
    )
    val_pct = col3.number_input(
        "Validation Set Percentage",
        0,
        100,
        15,
        1,
        on_change=utils.reset_validation_trigger,
        key=st.session_state["number_input_val_key"],
    )

    # Percentages must total 100, train and validation must be non-empty,
    # and check_min_images must accept the number of uploaded files
    pct_check = train_pct + test_pct + val_pct
    pct_condition_check = (
        pct_check == 100
        and train_pct > 0
        and val_pct > 0
        and model_training_functions.check_min_images(
            len(_uploaded_files("uploaded_files")), train_pct, val_pct, test_pct
        )
    )

    if not pct_condition_check:
        file_uploader_container.warning(
            "The percentages for train, test, and validation sets should add up to 100%, and train and validation set should not be empty.",
            icon="⚠️",
        )

    # Button to trigger validation
    if validate_button_col.button("Validate Input", use_container_width=True):
        st.session_state["validation_triggered"] = True
        st.session_state["is_valid"] = model_training_functions.check_valid_labels(
            _uploaded_files("uploaded_files"), label_type, class_dict
        )

        if st.session_state["is_valid"]:
            # Write the YOLO config, then rebuild the dataset folders
            model_training_functions.create_yolo_config_file(
                model_training_functions.get_path("config"),
                class_labels,
            )
            model_training_functions.clear_data_folders()
            paired_files = model_training_functions.pair_files(
                _uploaded_files("uploaded_files")
            )
            # NOTE(review): train_pct and test_pct are passed here; confirm
            # that split_and_save_files expects the test share (and not the
            # validation share) as its third argument.
            model_training_functions.split_and_save_files(
                paired_files, train_pct, test_pct
            )

# Code for "Direct Upload" method
elif split_method == "Direct Upload":
    with file_uploader_container:
        # Three columns for uploading train, val, and test files
        col1, col2, col3 = st.columns(3)

        with col1:
            utils.display_file_uploader(
                "uploaded_train_files",
                "Upload Training Images and Labels",
                st.session_state["file_uploader_train_key_training"],
                st.session_state["uploaded_files_cache_processing"],
            )

        with col2:
            utils.display_file_uploader(
                "uploaded_val_files",
                "Upload Validation Images and Labels",
                st.session_state["file_uploader_val_key_training"],
                st.session_state["uploaded_files_cache_processing"],
            )

        with col3:
            utils.display_file_uploader(
                "uploaded_test_files",
                "Upload Test Images and Labels",
                st.session_state["file_uploader_test_key_training"],
                st.session_state["uploaded_files_cache_processing"],
            )

    train_files = _uploaded_files("uploaded_train_files")
    val_files = _uploaded_files("uploaded_val_files")
    test_files = _uploaded_files("uploaded_test_files")

    # Train and validation uploads are mandatory; test files are optional
    pct_condition_check = len(train_files) > 0 and len(val_files) > 0

    if not pct_condition_check:
        file_uploader_container.warning(
            "The train and validation set should not be empty.",
            icon="⚠️",
        )

    # Button to trigger validation
    if validate_button_col.button("Validate Input", use_container_width=True):
        st.session_state["validation_triggered"] = True
        st.session_state["is_valid"] = model_training_functions.check_valid_labels(
            train_files + val_files + test_files,
            label_type,
            class_dict,
        )

        if st.session_state["is_valid"]:
            # Write the YOLO config, then rebuild the dataset folders
            model_training_functions.create_yolo_config_file(
                model_training_functions.get_path("config"),
                class_labels,
            )
            model_training_functions.clear_data_folders()
            model_training_functions.save_files_to_folder(train_files, "train")
            model_training_functions.save_files_to_folder(val_files, "val")

            # Only save test files if they are uploaded
            if len(test_files) > 0:
                model_training_functions.save_files_to_folder(test_files, "test")

# Start training only after a validation was triggered, the split-specific
# checks pass and the labels were found valid; otherwise tell the user what
# is still missing. Shared by both split methods.
ready_to_train = (
    st.session_state["validation_triggered"]
    and pct_condition_check
    and st.session_state["is_valid"]
)
if ready_to_train:
    model_training_functions.start_yolo_training(selected_training, class_labels)
else:
    st.warning(
        "Please upload valid input, select valid parameters, and click **Validate Input**.",
        icon="⚠️",
    )