msmaje commited on
Commit
2b7e143
Β·
verified Β·
1 Parent(s): 3c14fdc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -61
app.py CHANGED
@@ -35,16 +35,16 @@ TRAINING_LOGS = []
35
  CURRENT_MODEL = None
36
  CURRENT_TOKENIZER = None
37
 
 
 
38
  def load_and_prepare_local_dataset(file_path, text_column, label_column, test_size=0.2):
39
  """Load and prepare local CSV dataset for training"""
40
  try:
41
  if not os.path.exists(file_path):
42
  raise FileNotFoundError(f"Dataset file not found: {file_path}")
43
 
44
- # Load the CSV file
45
  df = pd.read_csv(file_path)
46
 
47
- # Verify required columns exist
48
  if text_column not in df.columns:
49
  available_cols = list(df.columns)
50
  raise ValueError(f"Text column '{text_column}' not found. Available columns: {available_cols}")
@@ -53,24 +53,19 @@ def load_and_prepare_local_dataset(file_path, text_column, label_column, test_si
53
  available_cols = list(df.columns)
54
  raise ValueError(f"Label column '{label_column}' not found. Available columns: {available_cols}")
55
 
56
- # Clean the data
57
  df = df.dropna(subset=[text_column, label_column])
58
  df[text_column] = df[text_column].astype(str)
59
 
60
- # Handle different label formats
61
  if df[label_column].dtype == 'object':
62
- # If labels are text, convert to indices
63
  unique_labels = df[label_column].unique()
64
  if len(unique_labels) > len(CATEGORIES):
65
  raise ValueError(f"Too many unique labels ({len(unique_labels)}). Expected max {len(CATEGORIES)}")
66
 
67
- # Try to map text labels to our categories
68
  label_mapping = {}
69
  for label in unique_labels:
70
  if label in category_to_idx:
71
  label_mapping[label] = category_to_idx[label]
72
  else:
73
- # Auto-assign if not found
74
  available_indices = set(range(len(CATEGORIES))) - set(label_mapping.values())
75
  if available_indices:
76
  label_mapping[label] = min(available_indices)
@@ -79,14 +74,11 @@ def load_and_prepare_local_dataset(file_path, text_column, label_column, test_si
79
 
80
  df['label_idx'] = df[label_column].map(label_mapping)
81
  else:
82
- # If labels are already numeric
83
  df['label_idx'] = df[label_column].astype(int)
84
 
85
- # Verify label indices are valid
86
  if df['label_idx'].min() < 0 or df['label_idx'].max() >= len(CATEGORIES):
87
  raise ValueError(f"Label indices must be between 0 and {len(CATEGORIES)-1}")
88
 
89
- # Create train/validation split
90
  train_df, val_df = train_test_split(
91
  df,
92
  test_size=test_size,
@@ -94,7 +86,6 @@ def load_and_prepare_local_dataset(file_path, text_column, label_column, test_si
94
  stratify=df['label_idx']
95
  )
96
 
97
- # Convert to Hugging Face datasets
98
  train_dataset = Dataset.from_pandas(train_df[[text_column, 'label_idx']])
99
  val_dataset = Dataset.from_pandas(val_df[[text_column, 'label_idx']])
100
 
@@ -114,7 +105,6 @@ def preview_dataset(uploaded_file, text_column, label_column):
114
  if uploaded_file is None:
115
  return "Please upload a dataset file first."
116
 
117
- # Get the file path from the uploaded file
118
  file_path = uploaded_file.name if hasattr(uploaded_file, 'name') else uploaded_file
119
 
120
  df = pd.read_csv(file_path)
@@ -162,11 +152,9 @@ def validate_hub_model_id(username, model_name):
162
  if not username or not model_name:
163
  return None, "Please provide both username and model name"
164
 
165
- # Clean up the model name
166
  model_name = model_name.strip().lower().replace(" ", "-")
167
  model_name = ''.join(c for c in model_name if c.isalnum() or c in ['-', '_'])
168
 
169
- # Construct the full model ID
170
  hub_model_id = f"{username}/{model_name}"
171
 
172
  return hub_model_id, None
@@ -176,7 +164,6 @@ def load_model(model_path):
176
  global CURRENT_MODEL, CURRENT_TOKENIZER
177
 
178
  try:
179
- # Try loading from local path first
180
  if os.path.exists(model_path):
181
  CURRENT_TOKENIZER = AutoTokenizer.from_pretrained(model_path)
182
  CURRENT_MODEL = BertForSequenceClassification.from_pretrained(
@@ -185,7 +172,6 @@ def load_model(model_path):
185
  )
186
  return f"βœ… Model loaded from local path: {model_path}"
187
 
188
- # If local path doesn't exist, try loading from Hub
189
  try:
190
  CURRENT_TOKENIZER = AutoTokenizer.from_pretrained(model_path)
191
  CURRENT_MODEL = BertForSequenceClassification.from_pretrained(
@@ -194,7 +180,6 @@ def load_model(model_path):
194
  )
195
  return f"βœ… Model loaded from Hugging Face Hub: {model_path}"
196
  except Exception as hub_error:
197
- # If both local and hub loading fail, fall back to base model
198
  CURRENT_TOKENIZER = AutoTokenizer.from_pretrained("bert-base-uncased")
199
  CURRENT_MODEL = BertForSequenceClassification.from_pretrained(
200
  "bert-base-uncased",
@@ -242,7 +227,6 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
242
  TRAINING_LOGS.append(login_result)
243
  yield "\n".join(TRAINING_LOGS)
244
 
245
- # Validate hub model ID if pushing to hub
246
  if push_to_hub:
247
  hub_model_id, error = validate_hub_model_id(username, model_name)
248
  if error:
@@ -252,7 +236,6 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
252
  else:
253
  hub_model_id = None
254
 
255
- # Validate uploaded file
256
  if uploaded_file is None:
257
  TRAINING_LOGS.append("❌ Please upload a dataset file")
258
  yield "\n".join(TRAINING_LOGS)
@@ -261,7 +244,6 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
261
  dataset_file = uploaded_file.name if hasattr(uploaded_file, 'name') else uploaded_file
262
 
263
  try:
264
- # Load and prepare dataset
265
  TRAINING_LOGS.append(f"πŸ“Š Loading dataset from uploaded file...")
266
  yield "\n".join(TRAINING_LOGS)
267
 
@@ -274,7 +256,6 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
274
  TRAINING_LOGS.append(f"- Validation samples: {len(dataset_dict['validation'])}")
275
  yield "\n".join(TRAINING_LOGS)
276
 
277
- # Load model and tokenizer
278
  TRAINING_LOGS.append("πŸ€– Loading BERT model and tokenizer...")
279
  yield "\n".join(TRAINING_LOGS)
280
 
@@ -287,14 +268,12 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
287
  TRAINING_LOGS.append("βœ… Model and tokenizer loaded")
288
  yield "\n".join(TRAINING_LOGS)
289
 
290
- # Tokenize datasets
291
  TRAINING_LOGS.append("πŸ”€ Tokenizing datasets...")
292
  yield "\n".join(TRAINING_LOGS)
293
 
294
  def tokenize_batch(examples):
295
  return tokenize_function(examples, tokenizer, final_text_col, 512)
296
 
297
- # Get columns to remove (keep only label column and tokenized features)
298
  columns_to_remove = [col for col in dataset_dict['train'].column_names if col != final_label_col]
299
 
300
  tokenized_datasets = dataset_dict.map(
@@ -303,17 +282,14 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
303
  remove_columns=columns_to_remove
304
  )
305
 
306
- # Rename label column to 'labels' (required by Trainer)
307
  tokenized_datasets = tokenized_datasets.rename_column(final_label_col, 'labels')
308
 
309
  TRAINING_LOGS.append("βœ… Tokenization completed")
310
  yield "\n".join(TRAINING_LOGS)
311
 
312
- # Set up training
313
  output_dir = Path(MODEL_PATH)
314
  output_dir.mkdir(parents=True, exist_ok=True)
315
 
316
- # Calculate steps
317
  total_steps = len(tokenized_datasets['train']) // batch_size * num_epochs
318
  eval_steps = max(10, min(100, total_steps // 4))
319
  save_steps = max(20, min(500, total_steps // 2))
@@ -326,7 +302,6 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
326
  TRAINING_LOGS.append(f"- Warmup steps: {warmup_steps}")
327
  yield "\n".join(TRAINING_LOGS)
328
 
329
- # Training arguments
330
  training_args = TrainingArguments(
331
  output_dir=str(output_dir),
332
  num_train_epochs=num_epochs,
@@ -353,10 +328,8 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
353
  remove_unused_columns=False,
354
  )
355
 
356
- # Data collator
357
  data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
358
 
359
- # Create trainer
360
  trainer = Trainer(
361
  model=model,
362
  args=training_args,
@@ -371,7 +344,6 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
371
  TRAINING_LOGS.append("πŸš€ Starting training...")
372
  yield "\n".join(TRAINING_LOGS)
373
 
374
- # Custom training loop with progress updates
375
  class ProgressCallback:
376
  def __init__(self, logs_list):
377
  self.logs = logs_list
@@ -395,7 +367,6 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
395
  progress_callback = ProgressCallback(TRAINING_LOGS)
396
  trainer.add_callback(progress_callback)
397
 
398
- # Train the model
399
  try:
400
  trainer.train()
401
  TRAINING_LOGS.append("βœ… Training completed successfully!")
@@ -405,21 +376,18 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
405
  yield "\n".join(TRAINING_LOGS)
406
  return
407
 
408
- # Save the model
409
  TRAINING_LOGS.append("πŸ’Ύ Saving model...")
410
  yield "\n".join(TRAINING_LOGS)
411
 
412
  trainer.save_model()
413
  tokenizer.save_pretrained(output_dir)
414
 
415
- # Update global model and tokenizer
416
  CURRENT_MODEL = model
417
  CURRENT_TOKENIZER = tokenizer
418
 
419
  TRAINING_LOGS.append("βœ… Model saved successfully!")
420
  yield "\n".join(TRAINING_LOGS)
421
 
422
- # Final evaluation
423
  TRAINING_LOGS.append("πŸ“Š Running final evaluation...")
424
  yield "\n".join(TRAINING_LOGS)
425
 
@@ -432,7 +400,6 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
432
  else:
433
  TRAINING_LOGS.append(f" {key}: {value}")
434
 
435
- # Save results
436
  with open(output_dir / "eval_results.json", "w") as f:
437
  json.dump(eval_results, f, indent=2)
438
 
@@ -441,7 +408,6 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
441
 
442
  yield "\n".join(TRAINING_LOGS)
443
 
444
- # Push to hub if requested
445
  if push_to_hub and hub_model_id:
446
  TRAINING_LOGS.append(f"πŸ€— Pushing to Hugging Face Hub: {hub_model_id}")
447
  yield "\n".join(TRAINING_LOGS)
@@ -465,7 +431,6 @@ def predict_text(text, model_path):
465
  """Make a prediction on a single text input"""
466
  global CURRENT_MODEL, CURRENT_TOKENIZER
467
 
468
- # Load the model if it's not loaded or a different one is requested
469
  if CURRENT_MODEL is None or model_path != MODEL_PATH:
470
  load_result = load_model(model_path)
471
  if load_result.startswith("❌"):
@@ -475,24 +440,19 @@ def predict_text(text, model_path):
475
  if not text.strip():
476
  return "Please enter some text to classify."
477
 
478
- # Check if text was truncated
479
  original_tokens = CURRENT_TOKENIZER(text, truncation=False)
480
  was_truncated = len(original_tokens['input_ids']) > 512
481
 
482
- # Tokenize input
483
  inputs = CURRENT_TOKENIZER(text, return_tensors="pt", truncation=True, max_length=512)
484
 
485
- # Make prediction
486
  with torch.no_grad():
487
  outputs = CURRENT_MODEL(**inputs)
488
  predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
489
  predicted_class_id = predictions.argmax().item()
490
  confidence = predictions.max().item()
491
 
492
- # Get predicted category
493
  predicted_category = idx_to_category[predicted_class_id]
494
 
495
- # Format result
496
  truncation_warning = "\n\n⚠️ Note: This complaint was truncated to fit BERT's 512 token limit." if was_truncated else ""
497
 
498
  result = []
@@ -516,21 +476,19 @@ def predict_csv(csv_file, model_path):
516
  """Make predictions on a CSV file with complaints"""
517
  global CURRENT_MODEL, CURRENT_TOKENIZER
518
 
519
- # Load the model if needed
520
  if CURRENT_MODEL is None or model_path != MODEL_PATH:
521
  load_result = load_model(model_path)
522
  if load_result.startswith("❌"):
523
- return load_result
524
 
525
  try:
526
- # Read the CSV file
527
  if hasattr(csv_file, 'name'):
528
  df = pd.read_csv(csv_file.name)
529
  else:
530
  df = pd.read_csv(csv_file)
531
 
532
  if 'complaint' not in df.columns:
533
- return "❌ CSV file must have a 'complaint' column"
534
 
535
  results = []
536
  predictions_list = []
@@ -539,13 +497,11 @@ def predict_csv(csv_file, model_path):
539
  for i, row in enumerate(df.iterrows()):
540
  complaint = str(row[1]['complaint'])
541
 
542
- # Check for truncation
543
  original_tokens = CURRENT_TOKENIZER(complaint, truncation=False)
544
  was_truncated = len(original_tokens['input_ids']) > 512
545
  if was_truncated:
546
  truncated_count += 1
547
 
548
- # Predict
549
  inputs = CURRENT_TOKENIZER(complaint, return_tensors="pt", truncation=True, max_length=512)
550
  with torch.no_grad():
551
  outputs = CURRENT_MODEL(**inputs)
@@ -573,16 +529,15 @@ def predict_csv(csv_file, model_path):
573
  if truncated_count > 0:
574
  results.append(f"\n⚠️ {truncated_count} complaints were truncated to fit BERT's 512 token limit.")
575
 
576
- # Save full results to a CSV file
577
  results_df = pd.DataFrame(predictions_list)
578
  results_file = "prediction_results.csv"
579
  results_df.to_csv(results_file, index=False)
580
  results.append(f"\nπŸ’Ύ Full results saved to {results_file}")
581
 
582
- return "\n".join(results)
583
 
584
  except Exception as e:
585
- return f"❌ CSV processing failed: {str(e)}"
586
 
587
  def push_to_hub_after_training(model_path, username, model_name, token):
588
  """Push a trained model to Hugging Face Hub"""
@@ -594,7 +549,6 @@ def push_to_hub_after_training(model_path, username, model_name, token):
594
  if error:
595
  return f"❌ {error}"
596
 
597
- # Login and load model
598
  login(token)
599
  if not os.path.exists(model_path):
600
  return "❌ No trained model found. Please train a model first."
@@ -605,7 +559,6 @@ def push_to_hub_after_training(model_path, username, model_name, token):
605
  except Exception as e:
606
  return f"❌ Failed to load model: {str(e)}"
607
 
608
- # Push to Hub
609
  try:
610
  model.push_to_hub(hub_model_id)
611
  tokenizer.push_to_hub(hub_model_id)
@@ -650,6 +603,8 @@ def display_available_datasets():
650
  else:
651
  return "No CSV files found in the current directory."
652
 
 
 
653
  # Initialize tokenizer on startup
654
  if CURRENT_TOKENIZER is None:
655
  try:
@@ -661,7 +616,7 @@ if CURRENT_TOKENIZER is None:
661
  print("πŸš€ Launching BERT Complaint Classifier...")
662
  print("πŸ“ Available at: http://localhost:7860")
663
 
664
- # The entire Gradio UI definition must be within a single block
665
  with gr.Blocks(title="BERT Complaint Classifier", theme=gr.themes.Soft()) as app:
666
  gr.Markdown("# BERT Complaint Classifier πŸ—£οΈπŸ€–")
667
  gr.Markdown("Fine-tune a BERT model or use an existing one to classify customer complaints.")
@@ -724,7 +679,6 @@ with gr.Blocks(title="BERT Complaint Classifier", theme=gr.themes.Soft()) as app
724
  predict_btn = gr.Button("Classify Complaint", variant="primary")
725
  single_prediction_output = gr.Markdown("Prediction will appear here...")
726
 
727
- # Link token count to text input
728
  text_input.change(count_tokens, inputs=text_input, outputs=token_count_output)
729
 
730
  with gr.Tab("Predict from CSV"):
@@ -740,7 +694,6 @@ with gr.Blocks(title="BERT Complaint Classifier", theme=gr.themes.Soft()) as app
740
  csv_prediction_output = gr.Markdown("Predictions will appear here...")
741
  download_link = gr.File(label="Download Full Predictions", interactive=False)
742
 
743
- # Link prediction buttons to functions
744
  predict_btn.click(
745
  predict_text,
746
  inputs=[text_input, model_path_input],
@@ -797,16 +750,14 @@ with gr.Blocks(title="BERT Complaint Classifier", theme=gr.themes.Soft()) as app
797
  hub_token_input_push = gr.Textbox(label="Hugging Face Token", type="password")
798
 
799
  push_btn = gr.Button("πŸš€ Push Model to Hub", variant="primary")
800
- push_output = gr.verse("Results will appear here...")
801
 
802
- # Link the push button
803
  push_btn.click(
804
  push_to_hub_after_training,
805
  inputs=[gr.Textbox(value=MODEL_PATH, visible=False), hub_username_input_push, hub_model_name_input_push, hub_token_input_push],
806
  outputs=push_output
807
  )
808
 
809
- # All button clicks and UI logic now correctly indented within the app block
810
  preview_btn.click(
811
  preview_dataset,
812
  inputs=[uploaded_file, text_column_input, label_column_input],
@@ -835,10 +786,8 @@ with gr.Blocks(title="BERT Complaint Classifier", theme=gr.themes.Soft()) as app
835
  outputs=available_datasets
836
  )
837
 
838
- # Show datasets on load
839
  app.load(display_available_datasets, outputs=available_datasets)
840
 
841
- # Launch the app
842
  if __name__ == "__main__":
843
  app.launch(
844
  server_name="0.0.0.0",
 
35
  CURRENT_MODEL = None
36
  CURRENT_TOKENIZER = None
37
 
38
+ # --- Application Logic Functions (No change needed here, they are correctly indented) ---
39
+
40
  def load_and_prepare_local_dataset(file_path, text_column, label_column, test_size=0.2):
41
  """Load and prepare local CSV dataset for training"""
42
  try:
43
  if not os.path.exists(file_path):
44
  raise FileNotFoundError(f"Dataset file not found: {file_path}")
45
 
 
46
  df = pd.read_csv(file_path)
47
 
 
48
  if text_column not in df.columns:
49
  available_cols = list(df.columns)
50
  raise ValueError(f"Text column '{text_column}' not found. Available columns: {available_cols}")
 
53
  available_cols = list(df.columns)
54
  raise ValueError(f"Label column '{label_column}' not found. Available columns: {available_cols}")
55
 
 
56
  df = df.dropna(subset=[text_column, label_column])
57
  df[text_column] = df[text_column].astype(str)
58
 
 
59
  if df[label_column].dtype == 'object':
 
60
  unique_labels = df[label_column].unique()
61
  if len(unique_labels) > len(CATEGORIES):
62
  raise ValueError(f"Too many unique labels ({len(unique_labels)}). Expected max {len(CATEGORIES)}")
63
 
 
64
  label_mapping = {}
65
  for label in unique_labels:
66
  if label in category_to_idx:
67
  label_mapping[label] = category_to_idx[label]
68
  else:
 
69
  available_indices = set(range(len(CATEGORIES))) - set(label_mapping.values())
70
  if available_indices:
71
  label_mapping[label] = min(available_indices)
 
74
 
75
  df['label_idx'] = df[label_column].map(label_mapping)
76
  else:
 
77
  df['label_idx'] = df[label_column].astype(int)
78
 
 
79
  if df['label_idx'].min() < 0 or df['label_idx'].max() >= len(CATEGORIES):
80
  raise ValueError(f"Label indices must be between 0 and {len(CATEGORIES)-1}")
81
 
 
82
  train_df, val_df = train_test_split(
83
  df,
84
  test_size=test_size,
 
86
  stratify=df['label_idx']
87
  )
88
 
 
89
  train_dataset = Dataset.from_pandas(train_df[[text_column, 'label_idx']])
90
  val_dataset = Dataset.from_pandas(val_df[[text_column, 'label_idx']])
91
 
 
105
  if uploaded_file is None:
106
  return "Please upload a dataset file first."
107
 
 
108
  file_path = uploaded_file.name if hasattr(uploaded_file, 'name') else uploaded_file
109
 
110
  df = pd.read_csv(file_path)
 
152
  if not username or not model_name:
153
  return None, "Please provide both username and model name"
154
 
 
155
  model_name = model_name.strip().lower().replace(" ", "-")
156
  model_name = ''.join(c for c in model_name if c.isalnum() or c in ['-', '_'])
157
 
 
158
  hub_model_id = f"{username}/{model_name}"
159
 
160
  return hub_model_id, None
 
164
  global CURRENT_MODEL, CURRENT_TOKENIZER
165
 
166
  try:
 
167
  if os.path.exists(model_path):
168
  CURRENT_TOKENIZER = AutoTokenizer.from_pretrained(model_path)
169
  CURRENT_MODEL = BertForSequenceClassification.from_pretrained(
 
172
  )
173
  return f"βœ… Model loaded from local path: {model_path}"
174
 
 
175
  try:
176
  CURRENT_TOKENIZER = AutoTokenizer.from_pretrained(model_path)
177
  CURRENT_MODEL = BertForSequenceClassification.from_pretrained(
 
180
  )
181
  return f"βœ… Model loaded from Hugging Face Hub: {model_path}"
182
  except Exception as hub_error:
 
183
  CURRENT_TOKENIZER = AutoTokenizer.from_pretrained("bert-base-uncased")
184
  CURRENT_MODEL = BertForSequenceClassification.from_pretrained(
185
  "bert-base-uncased",
 
227
  TRAINING_LOGS.append(login_result)
228
  yield "\n".join(TRAINING_LOGS)
229
 
 
230
  if push_to_hub:
231
  hub_model_id, error = validate_hub_model_id(username, model_name)
232
  if error:
 
236
  else:
237
  hub_model_id = None
238
 
 
239
  if uploaded_file is None:
240
  TRAINING_LOGS.append("❌ Please upload a dataset file")
241
  yield "\n".join(TRAINING_LOGS)
 
244
  dataset_file = uploaded_file.name if hasattr(uploaded_file, 'name') else uploaded_file
245
 
246
  try:
 
247
  TRAINING_LOGS.append(f"πŸ“Š Loading dataset from uploaded file...")
248
  yield "\n".join(TRAINING_LOGS)
249
 
 
256
  TRAINING_LOGS.append(f"- Validation samples: {len(dataset_dict['validation'])}")
257
  yield "\n".join(TRAINING_LOGS)
258
 
 
259
  TRAINING_LOGS.append("πŸ€– Loading BERT model and tokenizer...")
260
  yield "\n".join(TRAINING_LOGS)
261
 
 
268
  TRAINING_LOGS.append("βœ… Model and tokenizer loaded")
269
  yield "\n".join(TRAINING_LOGS)
270
 
 
271
  TRAINING_LOGS.append("πŸ”€ Tokenizing datasets...")
272
  yield "\n".join(TRAINING_LOGS)
273
 
274
  def tokenize_batch(examples):
275
  return tokenize_function(examples, tokenizer, final_text_col, 512)
276
 
 
277
  columns_to_remove = [col for col in dataset_dict['train'].column_names if col != final_label_col]
278
 
279
  tokenized_datasets = dataset_dict.map(
 
282
  remove_columns=columns_to_remove
283
  )
284
 
 
285
  tokenized_datasets = tokenized_datasets.rename_column(final_label_col, 'labels')
286
 
287
  TRAINING_LOGS.append("βœ… Tokenization completed")
288
  yield "\n".join(TRAINING_LOGS)
289
 
 
290
  output_dir = Path(MODEL_PATH)
291
  output_dir.mkdir(parents=True, exist_ok=True)
292
 
 
293
  total_steps = len(tokenized_datasets['train']) // batch_size * num_epochs
294
  eval_steps = max(10, min(100, total_steps // 4))
295
  save_steps = max(20, min(500, total_steps // 2))
 
302
  TRAINING_LOGS.append(f"- Warmup steps: {warmup_steps}")
303
  yield "\n".join(TRAINING_LOGS)
304
 
 
305
  training_args = TrainingArguments(
306
  output_dir=str(output_dir),
307
  num_train_epochs=num_epochs,
 
328
  remove_unused_columns=False,
329
  )
330
 
 
331
  data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
332
 
 
333
  trainer = Trainer(
334
  model=model,
335
  args=training_args,
 
344
  TRAINING_LOGS.append("πŸš€ Starting training...")
345
  yield "\n".join(TRAINING_LOGS)
346
 
 
347
  class ProgressCallback:
348
  def __init__(self, logs_list):
349
  self.logs = logs_list
 
367
  progress_callback = ProgressCallback(TRAINING_LOGS)
368
  trainer.add_callback(progress_callback)
369
 
 
370
  try:
371
  trainer.train()
372
  TRAINING_LOGS.append("βœ… Training completed successfully!")
 
376
  yield "\n".join(TRAINING_LOGS)
377
  return
378
 
 
379
  TRAINING_LOGS.append("πŸ’Ύ Saving model...")
380
  yield "\n".join(TRAINING_LOGS)
381
 
382
  trainer.save_model()
383
  tokenizer.save_pretrained(output_dir)
384
 
 
385
  CURRENT_MODEL = model
386
  CURRENT_TOKENIZER = tokenizer
387
 
388
  TRAINING_LOGS.append("βœ… Model saved successfully!")
389
  yield "\n".join(TRAINING_LOGS)
390
 
 
391
  TRAINING_LOGS.append("πŸ“Š Running final evaluation...")
392
  yield "\n".join(TRAINING_LOGS)
393
 
 
400
  else:
401
  TRAINING_LOGS.append(f" {key}: {value}")
402
 
 
403
  with open(output_dir / "eval_results.json", "w") as f:
404
  json.dump(eval_results, f, indent=2)
405
 
 
408
 
409
  yield "\n".join(TRAINING_LOGS)
410
 
 
411
  if push_to_hub and hub_model_id:
412
  TRAINING_LOGS.append(f"πŸ€— Pushing to Hugging Face Hub: {hub_model_id}")
413
  yield "\n".join(TRAINING_LOGS)
 
431
  """Make a prediction on a single text input"""
432
  global CURRENT_MODEL, CURRENT_TOKENIZER
433
 
 
434
  if CURRENT_MODEL is None or model_path != MODEL_PATH:
435
  load_result = load_model(model_path)
436
  if load_result.startswith("❌"):
 
440
  if not text.strip():
441
  return "Please enter some text to classify."
442
 
 
443
  original_tokens = CURRENT_TOKENIZER(text, truncation=False)
444
  was_truncated = len(original_tokens['input_ids']) > 512
445
 
 
446
  inputs = CURRENT_TOKENIZER(text, return_tensors="pt", truncation=True, max_length=512)
447
 
 
448
  with torch.no_grad():
449
  outputs = CURRENT_MODEL(**inputs)
450
  predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
451
  predicted_class_id = predictions.argmax().item()
452
  confidence = predictions.max().item()
453
 
 
454
  predicted_category = idx_to_category[predicted_class_id]
455
 
 
456
  truncation_warning = "\n\n⚠️ Note: This complaint was truncated to fit BERT's 512 token limit." if was_truncated else ""
457
 
458
  result = []
 
476
  """Make predictions on a CSV file with complaints"""
477
  global CURRENT_MODEL, CURRENT_TOKENIZER
478
 
 
479
  if CURRENT_MODEL is None or model_path != MODEL_PATH:
480
  load_result = load_model(model_path)
481
  if load_result.startswith("❌"):
482
+ return load_result, None
483
 
484
  try:
 
485
  if hasattr(csv_file, 'name'):
486
  df = pd.read_csv(csv_file.name)
487
  else:
488
  df = pd.read_csv(csv_file)
489
 
490
  if 'complaint' not in df.columns:
491
+ return "❌ CSV file must have a 'complaint' column", None
492
 
493
  results = []
494
  predictions_list = []
 
497
  for i, row in enumerate(df.iterrows()):
498
  complaint = str(row[1]['complaint'])
499
 
 
500
  original_tokens = CURRENT_TOKENIZER(complaint, truncation=False)
501
  was_truncated = len(original_tokens['input_ids']) > 512
502
  if was_truncated:
503
  truncated_count += 1
504
 
 
505
  inputs = CURRENT_TOKENIZER(complaint, return_tensors="pt", truncation=True, max_length=512)
506
  with torch.no_grad():
507
  outputs = CURRENT_MODEL(**inputs)
 
529
  if truncated_count > 0:
530
  results.append(f"\n⚠️ {truncated_count} complaints were truncated to fit BERT's 512 token limit.")
531
 
 
532
  results_df = pd.DataFrame(predictions_list)
533
  results_file = "prediction_results.csv"
534
  results_df.to_csv(results_file, index=False)
535
  results.append(f"\nπŸ’Ύ Full results saved to {results_file}")
536
 
537
+ return "\n".join(results), results_file
538
 
539
  except Exception as e:
540
+ return f"❌ CSV processing failed: {str(e)}", None
541
 
542
  def push_to_hub_after_training(model_path, username, model_name, token):
543
  """Push a trained model to Hugging Face Hub"""
 
549
  if error:
550
  return f"❌ {error}"
551
 
 
552
  login(token)
553
  if not os.path.exists(model_path):
554
  return "❌ No trained model found. Please train a model first."
 
559
  except Exception as e:
560
  return f"❌ Failed to load model: {str(e)}"
561
 
 
562
  try:
563
  model.push_to_hub(hub_model_id)
564
  tokenizer.push_to_hub(hub_model_id)
 
603
  else:
604
  return "No CSV files found in the current directory."
605
 
606
+ # --- Gradio UI Definition (Correctly structured) ---
607
+
608
  # Initialize tokenizer on startup
609
  if CURRENT_TOKENIZER is None:
610
  try:
 
616
  print("πŸš€ Launching BERT Complaint Classifier...")
617
  print("πŸ“ Available at: http://localhost:7860")
618
 
619
+ # The entire Gradio UI definition must be within this single block
620
  with gr.Blocks(title="BERT Complaint Classifier", theme=gr.themes.Soft()) as app:
621
  gr.Markdown("# BERT Complaint Classifier πŸ—£οΈπŸ€–")
622
  gr.Markdown("Fine-tune a BERT model or use an existing one to classify customer complaints.")
 
679
  predict_btn = gr.Button("Classify Complaint", variant="primary")
680
  single_prediction_output = gr.Markdown("Prediction will appear here...")
681
 
 
682
  text_input.change(count_tokens, inputs=text_input, outputs=token_count_output)
683
 
684
  with gr.Tab("Predict from CSV"):
 
694
  csv_prediction_output = gr.Markdown("Predictions will appear here...")
695
  download_link = gr.File(label="Download Full Predictions", interactive=False)
696
 
 
697
  predict_btn.click(
698
  predict_text,
699
  inputs=[text_input, model_path_input],
 
750
  hub_token_input_push = gr.Textbox(label="Hugging Face Token", type="password")
751
 
752
  push_btn = gr.Button("πŸš€ Push Model to Hub", variant="primary")
753
+ push_output = gr.Textbox(label="Results", lines=3, interactive=False)
754
 
 
755
  push_btn.click(
756
  push_to_hub_after_training,
757
  inputs=[gr.Textbox(value=MODEL_PATH, visible=False), hub_username_input_push, hub_model_name_input_push, hub_token_input_push],
758
  outputs=push_output
759
  )
760
 
 
761
  preview_btn.click(
762
  preview_dataset,
763
  inputs=[uploaded_file, text_column_input, label_column_input],
 
786
  outputs=available_datasets
787
  )
788
 
 
789
  app.load(display_available_datasets, outputs=available_datasets)
790
 
 
791
  if __name__ == "__main__":
792
  app.launch(
793
  server_name="0.0.0.0",