pavanmutha commited on
Commit
7e1b7d0
·
verified ·
1 Parent(s): 48f5f7b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -1
app.py CHANGED
@@ -184,12 +184,19 @@ def compare_models():
184
  return results_df
185
 
186
  # 1. prepare_data should come first
187
- def prepare_data(df, target_column="target"):
188
  from sklearn.model_selection import train_test_split
 
 
 
 
 
189
  X = df.drop(columns=[target_column])
190
  y = df[target_column]
 
191
  return train_test_split(X, y, test_size=0.2, random_state=42)
192
 
 
193
  def train_model(_):
194
  try:
195
  wandb.login(key=os.environ.get("WANDB_API_KEY"))
 
184
  return results_df
185
 
186
  # 1. prepare_data should come first
187
+ def prepare_data(df, target_column=None):
188
  from sklearn.model_selection import train_test_split
189
+
190
+ # If no target column is specified, select the first object column or the last column
191
+ if target_column is None:
192
+ target_column = df.select_dtypes(include=['object']).columns[0] if len(df.select_dtypes(include=['object']).columns) > 0 else df.columns[-1]
193
+
194
  X = df.drop(columns=[target_column])
195
  y = df[target_column]
196
+
197
  return train_test_split(X, y, test_size=0.2, random_state=42)
198
 
199
+
200
  def train_model(_):
201
  try:
202
  wandb.login(key=os.environ.get("WANDB_API_KEY"))