Spaces:
Build error
Build error
Commit ·
bdc8737
1
Parent(s): a865eda
fix: bugs
Browse files
model_finetuning/components/model_training_component.py
CHANGED
|
@@ -18,7 +18,7 @@ import os
|
|
| 18 |
|
| 19 |
def model_training():
|
| 20 |
dataset_path = st.session_state.get("selected_dataset", None)
|
| 21 |
-
if not dataset_path:
|
| 22 |
st.error("Please select a dataset to proceed.")
|
| 23 |
return
|
| 24 |
|
|
@@ -36,6 +36,8 @@ def model_training():
|
|
| 36 |
|
| 37 |
test_size = st.selectbox("Select Test Size", options=[0.1, 0.2, 0.3, 0.4, 0.5], index=1)
|
| 38 |
train_df, val_df = train_test_split(annotations_df, test_size=test_size, random_state=42)
|
|
|
|
|
|
|
| 39 |
st.write(f"Train Size: {len(train_df)} | Validation Size: {len(val_df)}")
|
| 40 |
col1, col2 = st.columns(2)
|
| 41 |
with col1:
|
|
@@ -50,6 +52,8 @@ def model_training():
|
|
| 50 |
for batch_size in batch_size_options:
|
| 51 |
if batch_size > ideal_batch_size:
|
| 52 |
ideal_batch_size_index = batch_size_options.index(batch_size) - 1
|
|
|
|
|
|
|
| 53 |
break
|
| 54 |
batch_size = st.selectbox("Select Batch Size", options=[2, 4, 8, 16, 32, 64, 128], index=ideal_batch_size_index)
|
| 55 |
|
|
|
|
| 18 |
|
| 19 |
def model_training():
|
| 20 |
dataset_path = st.session_state.get("selected_dataset", None)
|
| 21 |
+
if not dataset_path or dataset_path == "":
|
| 22 |
st.error("Please select a dataset to proceed.")
|
| 23 |
return
|
| 24 |
|
|
|
|
| 36 |
|
| 37 |
test_size = st.selectbox("Select Test Size", options=[0.1, 0.2, 0.3, 0.4, 0.5], index=1)
|
| 38 |
train_df, val_df = train_test_split(annotations_df, test_size=test_size, random_state=42)
|
| 39 |
+
if len(train_df) < 2:
|
| 40 |
+
st.error("Not enough data to train the model.")
|
| 41 |
st.write(f"Train Size: {len(train_df)} | Validation Size: {len(val_df)}")
|
| 42 |
col1, col2 = st.columns(2)
|
| 43 |
with col1:
|
|
|
|
| 52 |
for batch_size in batch_size_options:
|
| 53 |
if batch_size > ideal_batch_size:
|
| 54 |
ideal_batch_size_index = batch_size_options.index(batch_size) - 1
|
| 55 |
+
if ideal_batch_size_index < 0:
|
| 56 |
+
ideal_batch_size_index = 0
|
| 57 |
break
|
| 58 |
batch_size = st.selectbox("Select Batch Size", options=[2, 4, 8, 16, 32, 64, 128], index=ideal_batch_size_index)
|
| 59 |
|