msmaje committed on
Commit
dfd51d1
·
verified ·
1 Parent(s): 4b7e5cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -34
app.py CHANGED
@@ -127,19 +127,19 @@ def load_and_prepare_local_dataset(file_path, text_column, label_column, test_si
127
  except Exception as e:
128
  raise Exception(f"Error loading dataset: {str(e)}")
129
 
130
- def preview_dataset(file_path, text_column, label_column):
131
  """Preview a dataset file"""
132
  try:
133
- if not file_path:
134
- return "Please select a dataset file first."
135
 
136
- if not os.path.exists(file_path):
137
- return f"❌ File not found: {file_path}"
138
 
139
  df = pd.read_csv(file_path)
140
 
141
  preview_info = []
142
- preview_info.append(f"📊 **Dataset Preview: {file_path}**")
143
  preview_info.append(f"- **Total rows:** {len(df)}")
144
  preview_info.append(f"- **Columns:** {list(df.columns)}")
145
  preview_info.append("")
@@ -165,6 +165,8 @@ def preview_dataset(file_path, text_column, label_column):
165
 
166
  except Exception as e:
167
  return f"❌ Error previewing dataset: {str(e)}"
 
 
168
  """Login to Hugging Face"""
169
  global TOKEN
170
  TOKEN = token
@@ -308,7 +310,7 @@ def predict_csv(csv_file, model_path):
308
  except Exception as e:
309
  return f"❌ CSV processing failed: {str(e)}"
310
 
311
- def train_model(dataset_file, text_column, label_column, num_epochs, batch_size,
312
  learning_rate, hf_token, push_to_hub, username, model_name):
313
  """Start the model training process with local data"""
314
  global TRAINING_LOGS, MODEL_PATH
@@ -330,15 +332,18 @@ def train_model(dataset_file, text_column, label_column, num_epochs, batch_size,
330
  else:
331
  hub_model_id = None
332
 
333
- # Validate dataset file
334
- if not dataset_file or not os.path.exists(dataset_file):
335
- TRAINING_LOGS.append(f"❌ Dataset file not found: {dataset_file}")
336
  yield "\n".join(TRAINING_LOGS)
337
  return
338
 
 
 
 
339
  try:
340
  # Load and prepare the dataset
341
- TRAINING_LOGS.append(f"📊 Loading dataset from {dataset_file}...")
342
  yield "\n".join(TRAINING_LOGS)
343
 
344
  dataset_dict, final_text_col, final_label_col = load_and_prepare_local_dataset(
@@ -477,20 +482,15 @@ with gr.Blocks(title="BERT Complaint Classifier") as app:
477
  # Training Tab
478
  with gr.TabItem("Train Model"):
479
  gr.Markdown("### Train a New Model with Local Data")
480
- gr.Markdown("Select your local CSV file and configure training parameters")
481
 
482
- # Dataset selection and preview
483
  with gr.Row():
484
- with gr.Column(scale=2):
485
- dataset_file = gr.Dropdown(
486
- label="Select Dataset File",
487
- choices=[f for f in os.listdir(".") if f.endswith(".csv")],
488
- value=LOCAL_DATA_FILES[0] if LOCAL_DATA_FILES[0] in os.listdir(".") else None,
489
- allow_custom_value=True
490
- )
491
-
492
- with gr.Column(scale=1):
493
- refresh_btn = gr.Button("🔄 Refresh Files", size="sm")
494
 
495
  # Column configuration
496
  with gr.Row():
@@ -507,7 +507,7 @@ with gr.Blocks(title="BERT Complaint Classifier") as app:
507
 
508
  # Dataset preview
509
  preview_btn = gr.Button("📊 Preview Dataset", variant="secondary")
510
- dataset_preview = gr.Markdown("Select a dataset file and click 'Preview Dataset' to see its structure.")
511
 
512
  # Training parameters
513
  with gr.Row():
@@ -587,22 +587,20 @@ with gr.Blocks(title="BERT Complaint Classifier") as app:
587
  train_btn = gr.Button("Start Training", variant="primary")
588
  training_output = gr.Textbox(label="Training Progress", lines=10)
589
 
590
- # Connect the buttons
591
- post_train_push_btn.click(
592
- push_to_hub_after_training,
593
- inputs=[
594
- gr.Textbox(value=MODEL_PATH, visible=False),
595
- post_train_username,
596
- post_train_model_name,
597
- post_train_token
598
- ],
599
- outputs=post_train_status
600
  )
601
 
 
602
  train_btn.click(
603
  train_model,
604
  inputs=[
605
  dataset_file,
 
 
606
  num_epochs,
607
  batch_size,
608
  learning_rate,
@@ -614,6 +612,18 @@ with gr.Blocks(title="BERT Complaint Classifier") as app:
614
  outputs=training_output,
615
  show_progress="full"
616
  )
 
 
 
 
 
 
 
 
 
 
 
 
617
 
618
  # Classification Tab
619
  with gr.TabItem("Classify Complaints"):
 
127
  except Exception as e:
128
  raise Exception(f"Error loading dataset: {str(e)}")
129
 
130
+ def preview_dataset(uploaded_file, text_column, label_column):
131
  """Preview a dataset file"""
132
  try:
133
+ if uploaded_file is None:
134
+ return "Please upload a dataset file first."
135
 
136
+ # Get the file path from the uploaded file
137
+ file_path = uploaded_file.name if hasattr(uploaded_file, 'name') else uploaded_file
138
 
139
  df = pd.read_csv(file_path)
140
 
141
  preview_info = []
142
+ preview_info.append(f"📊 **Dataset Preview: {os.path.basename(file_path)}**")
143
  preview_info.append(f"- **Total rows:** {len(df)}")
144
  preview_info.append(f"- **Columns:** {list(df.columns)}")
145
  preview_info.append("")
 
165
 
166
  except Exception as e:
167
  return f"❌ Error previewing dataset: {str(e)}"
168
+
169
+ def login_to_hf(token):
170
  """Login to Hugging Face"""
171
  global TOKEN
172
  TOKEN = token
 
310
  except Exception as e:
311
  return f"❌ CSV processing failed: {str(e)}"
312
 
313
+ def train_model(uploaded_file, text_column, label_column, num_epochs, batch_size,
314
  learning_rate, hf_token, push_to_hub, username, model_name):
315
  """Start the model training process with local data"""
316
  global TRAINING_LOGS, MODEL_PATH
 
332
  else:
333
  hub_model_id = None
334
 
335
+ # Validate uploaded file
336
+ if uploaded_file is None:
337
+ TRAINING_LOGS.append("❌ Please upload a dataset file")
338
  yield "\n".join(TRAINING_LOGS)
339
  return
340
 
341
+ # Get the file path from the uploaded file
342
+ dataset_file = uploaded_file.name if hasattr(uploaded_file, 'name') else uploaded_file
343
+
344
  try:
345
  # Load and prepare the dataset
346
+ TRAINING_LOGS.append(f"📊 Loading dataset from uploaded file...")
347
  yield "\n".join(TRAINING_LOGS)
348
 
349
  dataset_dict, final_text_col, final_label_col = load_and_prepare_local_dataset(
 
482
  # Training Tab
483
  with gr.TabItem("Train Model"):
484
  gr.Markdown("### Train a New Model with Local Data")
485
+ gr.Markdown("Upload your CSV file and configure training parameters")
486
 
487
+ # Dataset upload
488
  with gr.Row():
489
+ dataset_file = gr.File(
490
+ label="Upload Dataset (CSV)",
491
+ file_types=[".csv"],
492
+ type="filepath"
493
+ )
 
 
 
 
 
494
 
495
  # Column configuration
496
  with gr.Row():
 
507
 
508
  # Dataset preview
509
  preview_btn = gr.Button("📊 Preview Dataset", variant="secondary")
510
+ dataset_preview = gr.Markdown("Upload a dataset file and click 'Preview Dataset' to see its structure.")
511
 
512
  # Training parameters
513
  with gr.Row():
 
587
  train_btn = gr.Button("Start Training", variant="primary")
588
  training_output = gr.Textbox(label="Training Progress", lines=10)
589
 
590
+ # Connect the preview button
591
+ preview_btn.click(
592
+ preview_dataset,
593
+ inputs=[dataset_file, text_column, label_column],
594
+ outputs=dataset_preview
 
 
 
 
 
595
  )
596
 
597
+ # Connect the training button
598
  train_btn.click(
599
  train_model,
600
  inputs=[
601
  dataset_file,
602
+ text_column,
603
+ label_column,
604
  num_epochs,
605
  batch_size,
606
  learning_rate,
 
612
  outputs=training_output,
613
  show_progress="full"
614
  )
615
+
616
+ # Connect the post-training push button
617
+ post_train_push_btn.click(
618
+ push_to_hub_after_training,
619
+ inputs=[
620
+ gr.Textbox(value=MODEL_PATH, visible=False),
621
+ post_train_username,
622
+ post_train_model_name,
623
+ post_train_token
624
+ ],
625
+ outputs=post_train_status
626
+ )
627
 
628
  # Classification Tab
629
  with gr.TabItem("Classify Complaints"):