msmaje commited on
Commit
d58a542
·
verified ·
1 Parent(s): dfd51d1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +403 -384
app.py CHANGED
@@ -2,13 +2,28 @@ import gradio as gr
2
  import torch
3
  import pandas as pd
4
  import os
5
- import tempfile
6
- import time
7
- import subprocess
8
  import json
 
 
 
 
 
 
 
9
  from huggingface_hub import login, HfApi
10
- from transformers import AutoTokenizer, BertForSequenceClassification
11
- from datasets import load_dataset, Dataset, DatasetDict
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  # Global variables
14
  MODEL_PATH = "local-model"
@@ -20,38 +35,6 @@ TRAINING_LOGS = []
20
  CURRENT_MODEL = None
21
  CURRENT_TOKENIZER = None
22
 
23
- # Local data files
24
- LOCAL_DATA_FILES = [
25
- "merged-test-data.csv",
26
- "test-category.csv",
27
- "test-complaint.csv"
28
- ]
29
-
30
- def get_available_datasets():
31
- """Get list of available local datasets"""
32
- available_files = []
33
- for file in LOCAL_DATA_FILES:
34
- if os.path.exists(file):
35
- try:
36
- df = pd.read_csv(file)
37
- available_files.append(f"{file} ({len(df)} rows)")
38
- except Exception as e:
39
- available_files.append(f"{file} (Error: {str(e)})")
40
- else:
41
- available_files.append(f"{file} (Not found)")
42
-
43
- # Also check for any other CSV files in the directory
44
- for file in os.listdir("."):
45
- if file.endswith(".csv") and file not in LOCAL_DATA_FILES:
46
- if os.path.exists(file):
47
- try:
48
- df = pd.read_csv(file)
49
- available_files.append(f"{file} ({len(df)} rows)")
50
- except:
51
- available_files.append(f"{file} (Error reading)")
52
-
53
- return available_files
54
-
55
  def load_and_prepare_local_dataset(file_path, text_column, label_column, test_size=0.2):
56
  """Load and prepare local CSV dataset for training"""
57
  try:
@@ -104,8 +87,6 @@ def load_and_prepare_local_dataset(file_path, text_column, label_column, test_si
104
  raise ValueError(f"Label indices must be between 0 and {len(CATEGORIES)-1}")
105
 
106
  # Create train/validation split
107
- from sklearn.model_selection import train_test_split
108
-
109
  train_df, val_df = train_test_split(
110
  df,
111
  test_size=test_size,
@@ -224,6 +205,262 @@ def load_model(model_path):
224
  except Exception as e:
225
  return f"❌ Failed to load model: {str(e)}"
226
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
  def predict_text(text, model_path):
228
  """Make a prediction on a single text input"""
229
  global CURRENT_MODEL, CURRENT_TOKENIZER
@@ -235,25 +472,45 @@ def predict_text(text, model_path):
235
  return load_result
236
 
237
  try:
 
 
 
 
 
 
 
238
  # Tokenize input
239
  inputs = CURRENT_TOKENIZER(text, return_tensors="pt", truncation=True, max_length=512)
240
 
241
  # Make prediction
242
  with torch.no_grad():
243
  outputs = CURRENT_MODEL(**inputs)
244
- predicted_idx = outputs.logits.argmax().item()
 
 
245
 
246
- # Get category from index
247
- predicted_category = idx_to_category[predicted_idx]
248
 
249
- # Check if text was truncated
250
- original_tokens = CURRENT_TOKENIZER(text, truncation=False)
251
- was_truncated = len(original_tokens['input_ids']) > 512
252
  truncation_warning = "\n\n⚠️ Note: This complaint was truncated to fit BERT's 512 token limit." if was_truncated else ""
253
 
254
- return f"Complaint: {text}\n\nPredicted Category: {predicted_category}{truncation_warning}"
 
 
 
 
 
 
 
 
 
 
 
 
 
255
  except Exception as e:
256
- return f"❌ Prediction failed: {str(e)}"
257
 
258
  def predict_csv(csv_file, model_path):
259
  """Make predictions on a CSV file with complaints"""
@@ -276,6 +533,7 @@ def predict_csv(csv_file, model_path):
276
  return "❌ CSV file must have a 'complaint' column"
277
 
278
  results = []
 
279
  truncated_count = 0
280
 
281
  for i, row in enumerate(df.iterrows()):
@@ -291,13 +549,22 @@ def predict_csv(csv_file, model_path):
291
  inputs = CURRENT_TOKENIZER(complaint, return_tensors="pt", truncation=True, max_length=512)
292
  with torch.no_grad():
293
  outputs = CURRENT_MODEL(**inputs)
294
- predicted_idx = outputs.logits.argmax().item()
 
 
295
 
296
  predicted_category = idx_to_category[predicted_idx]
 
 
 
 
 
 
297
 
298
  truncation_mark = " ⚠️" if was_truncated else ""
299
  preview = complaint if len(complaint) <= 50 else complaint[:47] + "..."
300
- results.append(f"Complaint {i+1}{truncation_mark}: {preview}\nPredicted Category: {predicted_category}\n")
 
301
 
302
  if i >= 19:
303
  results.append(f"... and {len(df) - 20} more (showing first 20 out of {len(df)} complaints)")
@@ -306,141 +573,17 @@ def predict_csv(csv_file, model_path):
306
  if truncated_count > 0:
307
  results.append(f"\n⚠️ {truncated_count} complaints were truncated to fit BERT's 512 token limit.")
308
 
 
 
 
 
 
 
309
  return "\n".join(results)
 
310
  except Exception as e:
311
  return f"❌ CSV processing failed: {str(e)}"
312
 
313
- def train_model(uploaded_file, text_column, label_column, num_epochs, batch_size,
314
- learning_rate, hf_token, push_to_hub, username, model_name):
315
- """Start the model training process with local data"""
316
- global TRAINING_LOGS, MODEL_PATH
317
-
318
- TRAINING_LOGS = [] # Reset logs at the start of training
319
-
320
- if hf_token:
321
- login_result = login_to_hf(hf_token)
322
- TRAINING_LOGS.append(login_result)
323
- yield "\n".join(TRAINING_LOGS)
324
-
325
- # Validate hub model ID if pushing to hub
326
- if push_to_hub:
327
- hub_model_id, error = validate_hub_model_id(username, model_name)
328
- if error:
329
- TRAINING_LOGS.append(f"❌ {error}")
330
- yield "\n".join(TRAINING_LOGS)
331
- return
332
- else:
333
- hub_model_id = None
334
-
335
- # Validate uploaded file
336
- if uploaded_file is None:
337
- TRAINING_LOGS.append("❌ Please upload a dataset file")
338
- yield "\n".join(TRAINING_LOGS)
339
- return
340
-
341
- # Get the file path from the uploaded file
342
- dataset_file = uploaded_file.name if hasattr(uploaded_file, 'name') else uploaded_file
343
-
344
- try:
345
- # Load and prepare the dataset
346
- TRAINING_LOGS.append(f"📊 Loading dataset from uploaded file...")
347
- yield "\n".join(TRAINING_LOGS)
348
-
349
- dataset_dict, final_text_col, final_label_col = load_and_prepare_local_dataset(
350
- dataset_file, text_column, label_column
351
- )
352
-
353
- TRAINING_LOGS.append(f"✅ Dataset loaded successfully!")
354
- TRAINING_LOGS.append(f"- Train samples: {len(dataset_dict['train'])}")
355
- TRAINING_LOGS.append(f"- Validation samples: {len(dataset_dict['validation'])}")
356
- yield "\n".join(TRAINING_LOGS)
357
-
358
- # Save dataset temporarily for the training script
359
- temp_dataset_path = "temp_dataset"
360
- os.makedirs(temp_dataset_path, exist_ok=True)
361
- dataset_dict.save_to_disk(temp_dataset_path)
362
-
363
- TRAINING_LOGS.append("💾 Dataset prepared for training...")
364
- yield "\n".join(TRAINING_LOGS)
365
-
366
- except Exception as e:
367
- TRAINING_LOGS.append(f"❌ Error preparing dataset: {str(e)}")
368
- yield "\n".join(TRAINING_LOGS)
369
- return
370
-
371
- # Create training command for local dataset
372
- cmd = [
373
- "python", "bert_finetune.py",
374
- "--dataset_path", temp_dataset_path, # Use local path instead of HF dataset name
375
- "--model_id", "bert-base-uncased",
376
- "--output_dir", MODEL_PATH,
377
- "--feature_column", final_text_col,
378
- "--label_column", final_label_col,
379
- "--num_labels", "3",
380
- "--num_train_epochs", str(num_epochs),
381
- "--batch_size", str(batch_size),
382
- "--learning_rate", str(learning_rate),
383
- "--max_length", "512"
384
- ]
385
-
386
- if push_to_hub and hub_model_id:
387
- cmd.extend(["--push_to_hub", "--hub_model_id", hub_model_id])
388
- if hf_token:
389
- cmd.extend(["--hf_token", hf_token])
390
-
391
- TRAINING_LOGS.append(f"🚀 Starting training with command: {' '.join(cmd)}")
392
- yield "\n".join(TRAINING_LOGS)
393
-
394
- try:
395
- process = subprocess.Popen(
396
- cmd,
397
- stdout=subprocess.PIPE,
398
- stderr=subprocess.STDOUT,
399
- universal_newlines=True,
400
- bufsize=1
401
- )
402
-
403
- TRAINING_LOGS.append("🔄 Training started...")
404
- yield "\n".join(TRAINING_LOGS)
405
-
406
- while True:
407
- line = process.stdout.readline()
408
- if not line and process.poll() is not None:
409
- break
410
- if line:
411
- TRAINING_LOGS.append(line.strip())
412
- yield "\n".join(TRAINING_LOGS)
413
-
414
- process.wait()
415
-
416
- if process.returncode == 0:
417
- TRAINING_LOGS.append("✅ Training completed successfully!")
418
- if push_to_hub and hub_model_id:
419
- TRAINING_LOGS.append(f"🤗 Model pushed to Hugging Face Hub: {hub_model_id}")
420
-
421
- # Load the trained model
422
- TRAINING_LOGS.append("📥 Loading trained model...")
423
- load_result = load_model(MODEL_PATH)
424
- TRAINING_LOGS.append(load_result)
425
-
426
- # Clean up temporary files
427
- import shutil
428
- try:
429
- shutil.rmtree(temp_dataset_path)
430
- TRAINING_LOGS.append("🧹 Cleaned up temporary files")
431
- except:
432
- pass
433
-
434
- # Final success message
435
- TRAINING_LOGS.append("\n✨ All done! Your model is ready to use.")
436
- else:
437
- TRAINING_LOGS.append(f"❌ Training failed with return code {process.returncode}")
438
-
439
- except Exception as e:
440
- TRAINING_LOGS.append(f"❌ Error during training: {str(e)}")
441
-
442
- yield "\n".join(TRAINING_LOGS)
443
-
444
  def push_to_hub_after_training(model_path, username, model_name, token):
445
  """Push a trained model to Hugging Face Hub"""
446
  try:
@@ -473,220 +616,96 @@ def push_to_hub_after_training(model_path, username, model_name, token):
473
  except Exception as e:
474
  return f"❌ Error: {str(e)}"
475
 
476
- # Create the Gradio Interface
477
- with gr.Blocks(title="BERT Complaint Classifier") as app:
478
- gr.Markdown("# BERT Complaint Category Classifier")
479
- gr.Markdown("A simple tool to train and use a BERT model for classifying customer complaints")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
480
 
481
- with gr.Tabs():
482
- # Training Tab
483
- with gr.TabItem("Train Model"):
484
- gr.Markdown("### Train a New Model with Local Data")
485
- gr.Markdown("Upload your CSV file and configure training parameters")
486
-
487
- # Dataset upload
488
- with gr.Row():
489
- dataset_file = gr.File(
490
- label="Upload Dataset (CSV)",
491
- file_types=[".csv"],
492
- type="filepath"
493
- )
494
-
495
- # Column configuration
496
- with gr.Row():
497
- text_column = gr.Textbox(
498
- label="Text Column Name",
499
- value="complaint",
500
- placeholder="e.g., complaint, text, description"
501
- )
502
- label_column = gr.Textbox(
503
- label="Label Column Name",
504
- value="category",
505
- placeholder="e.g., category, label, class"
506
- )
507
-
508
- # Dataset preview
509
- preview_btn = gr.Button("📊 Preview Dataset", variant="secondary")
510
- dataset_preview = gr.Markdown("Upload a dataset file and click 'Preview Dataset' to see its structure.")
511
-
512
- # Training parameters
513
- with gr.Row():
514
- num_epochs = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Number of Epochs")
515
- batch_size = gr.Slider(minimum=4, maximum=32, value=8, step=4, label="Batch Size")
516
- learning_rate = gr.Slider(minimum=1e-5, maximum=5e-5, value=2e-5, step=1e-5, label="Learning Rate")
517
-
518
- with gr.Accordion("Hugging Face Hub Settings", open=False):
519
- hf_token = gr.Textbox(
520
- label="Hugging Face Token (required for pushing to Hub)",
521
- type="password"
522
- )
523
-
524
- gr.Markdown("""### Choose when to push to Hub:
525
- 1. During Training: Model will be pushed automatically when training completes
526
- 2. After Training: You can push the trained model manually later""")
527
-
528
- # During Training Push
529
- with gr.Group():
530
- push_to_hub = gr.Checkbox(
531
- label="Push Model to Hub during training",
532
- value=False
533
- )
534
-
535
- with gr.Column(visible=False) as hub_settings:
536
- username = gr.Textbox(
537
- label="Hugging Face Username",
538
- placeholder="e.g., huggingface-username"
539
- )
540
- model_name = gr.Textbox(
541
- label="Model Name",
542
- placeholder="e.g., bert-complaint-classifier"
543
- )
544
-
545
- # Post-Training Push
546
- with gr.Group():
547
- post_train_push = gr.Checkbox(
548
- label="Push trained model to Hub after training",
549
- value=False
550
- )
551
-
552
- with gr.Column(visible=False) as post_train_settings:
553
- post_train_username = gr.Textbox(
554
- label="Hugging Face Username",
555
- placeholder="e.g., huggingface-username"
556
- )
557
- post_train_model_name = gr.Textbox(
558
- label="Model Name",
559
- placeholder="e.g., bert-complaint-classifier"
560
- )
561
- post_train_token = gr.Textbox(
562
- label="Hugging Face Token (if different from above)",
563
- type="password"
564
- )
565
- post_train_push_btn = gr.Button(
566
- "Push Model to Hub",
567
- variant="secondary"
568
- )
569
- post_train_status = gr.Textbox(label="Upload Status")
570
-
571
- # Show/hide settings based on checkboxes
572
- push_to_hub.change(
573
- lambda x: gr.update(visible=x),
574
- inputs=push_to_hub,
575
- outputs=hub_settings
576
- )
577
-
578
- post_train_push.change(
579
- lambda x: gr.update(visible=x),
580
- inputs=post_train_push,
581
- outputs=post_train_settings
582
- )
583
 
584
- gr.Markdown("### BERT Model Note")
585
- gr.Markdown("⚠️ BERT has a maximum sequence length of 512 tokens. Complaints longer than this will be truncated.")
 
 
 
586
 
587
- train_btn = gr.Button("Start Training", variant="primary")
588
- training_output = gr.Textbox(label="Training Progress", lines=10)
589
 
590
- # Connect the preview button
591
- preview_btn.click(
592
- preview_dataset,
593
- inputs=[dataset_file, text_column, label_column],
594
- outputs=dataset_preview
595
- )
596
 
597
- # Connect the training button
598
- train_btn.click(
599
- train_model,
600
- inputs=[
601
- dataset_file,
602
- text_column,
603
- label_column,
604
- num_epochs,
605
- batch_size,
606
- learning_rate,
607
- hf_token,
608
- push_to_hub,
609
- username,
610
- model_name
611
- ],
612
- outputs=training_output,
613
- show_progress="full"
614
- )
615
 
616
- # Connect the post-training push button
617
- post_train_push_btn.click(
618
- push_to_hub_after_training,
619
- inputs=[
620
- gr.Textbox(value=MODEL_PATH, visible=False),
621
- post_train_username,
622
- post_train_model_name,
623
- post_train_token
624
- ],
625
- outputs=post_train_status
626
- )
627
-
628
- # Classification Tab
629
- with gr.TabItem("Classify Complaints"):
630
- gr.Markdown("### Classify Customer Complaints")
631
 
632
- model_path = gr.Textbox(
633
- label="Model Path or Hugging Face ID",
634
- value="local-model",
635
- placeholder="e.g., local-model or your-username/bert-complaint-classifier"
636
- )
637
 
638
- with gr.Tabs():
639
- # Single Complaint Classification
640
- with gr.TabItem("Single Complaint"):
641
- text_input = gr.Textbox(
642
- label="Complaint Text",
643
- lines=5,
644
- placeholder="Enter a customer complaint here..."
645
- )
646
-
647
- classify_btn = gr.Button("Classify", variant="primary")
648
- token_info = gr.Markdown("Note: BERT has a 512 token limit. Longer complaints will be truncated.")
649
- text_output = gr.Textbox(label="Classification Result", lines=5)
650
-
651
- # Token counter
652
- def count_tokens(text):
653
- if not text or CURRENT_TOKENIZER is None:
654
- return "Enter text to see token count"
655
- tokens = CURRENT_TOKENIZER(text, truncation=False)
656
- count = len(tokens['input_ids'])
657
- if count > 512:
658
- return f"⚠️ **Token count: {count}/512** - Text will be truncated for BERT"
659
- else:
660
- return f"Token count: {count}/512"
661
-
662
- text_input.change(
663
- fn=count_tokens,
664
- inputs=text_input,
665
- outputs=token_info
666
- )
667
-
668
- classify_btn.click(
669
- predict_text,
670
- inputs=[text_input, model_path],
671
- outputs=text_output
672
- )
673
-
674
- # Batch Processing
675
- with gr.TabItem("Batch Processing"):
676
- gr.Markdown("Upload a CSV file with a 'complaint' column")
677
- csv_input = gr.File(label="Upload CSV", file_types=[".csv"])
678
- batch_classify_btn = gr.Button("Classify All", variant="primary")
679
- csv_output = gr.Textbox(label="Classification Results", lines=15)
680
-
681
- batch_classify_btn.click(
682
- predict_csv,
683
- inputs=[csv_input, model_path],
684
- outputs=csv_output
685
- )
686
 
687
  # Launch the app
688
  if __name__ == "__main__":
689
  # Initialize tokenizer on startup
690
  if CURRENT_TOKENIZER is None:
691
- CURRENT_TOKENIZER = AutoTokenizer.from_pretrained("bert-base-uncased")
692
- app.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import torch
3
  import pandas as pd
4
  import os
 
 
 
5
  import json
6
+ import logging
7
+ import numpy as np
8
+ from datetime import datetime
9
+ from pathlib import Path
10
+ from sklearn.metrics import accuracy_score, classification_report
11
+ from sklearn.model_selection import train_test_split
12
+
13
  from huggingface_hub import login, HfApi
14
+ from transformers import (
15
+ AutoTokenizer,
16
+ BertForSequenceClassification,
17
+ TrainingArguments,
18
+ Trainer,
19
+ DataCollatorWithPadding,
20
+ EarlyStoppingCallback
21
+ )
22
+ from datasets import Dataset, DatasetDict
23
+
24
+ # Set up logging
25
+ logging.basicConfig(level=logging.INFO)
26
+ logger = logging.getLogger(__name__)
27
 
28
  # Global variables
29
  MODEL_PATH = "local-model"
 
35
  CURRENT_MODEL = None
36
  CURRENT_TOKENIZER = None
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  def load_and_prepare_local_dataset(file_path, text_column, label_column, test_size=0.2):
39
  """Load and prepare local CSV dataset for training"""
40
  try:
 
87
  raise ValueError(f"Label indices must be between 0 and {len(CATEGORIES)-1}")
88
 
89
  # Create train/validation split
 
 
90
  train_df, val_df = train_test_split(
91
  df,
92
  test_size=test_size,
 
205
  except Exception as e:
206
  return f"❌ Failed to load model: {str(e)}"
207
 
208
+ def tokenize_function(examples, tokenizer, feature_column, max_length=512):
209
+ """Tokenize the input text"""
210
+ return tokenizer(
211
+ examples[feature_column],
212
+ truncation=True,
213
+ padding=False,
214
+ max_length=max_length
215
+ )
216
+
217
+ def compute_metrics(eval_pred):
218
+ """Compute metrics for evaluation"""
219
+ predictions, labels = eval_pred
220
+ predictions = np.argmax(predictions, axis=1)
221
+
222
+ accuracy = accuracy_score(labels, predictions)
223
+ report = classification_report(labels, predictions, output_dict=True, zero_division=0)
224
+
225
+ return {
226
+ 'accuracy': accuracy,
227
+ 'f1_macro': report['macro avg']['f1-score'],
228
+ 'f1_weighted': report['weighted avg']['f1-score'],
229
+ 'precision_macro': report['macro avg']['precision'],
230
+ 'recall_macro': report['macro avg']['recall']
231
+ }
232
+
233
+ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, batch_size,
234
+ learning_rate, hf_token, push_to_hub, username, model_name):
235
+ """Train the model using inline training (no subprocess)"""
236
+ global TRAINING_LOGS, MODEL_PATH, CURRENT_MODEL, CURRENT_TOKENIZER
237
+
238
+ TRAINING_LOGS = []
239
+
240
+ if hf_token:
241
+ login_result = login_to_hf(hf_token)
242
+ TRAINING_LOGS.append(login_result)
243
+ yield "\n".join(TRAINING_LOGS)
244
+
245
+ # Validate hub model ID if pushing to hub
246
+ if push_to_hub:
247
+ hub_model_id, error = validate_hub_model_id(username, model_name)
248
+ if error:
249
+ TRAINING_LOGS.append(f"❌ {error}")
250
+ yield "\n".join(TRAINING_LOGS)
251
+ return
252
+ else:
253
+ hub_model_id = None
254
+
255
+ # Validate uploaded file
256
+ if uploaded_file is None:
257
+ TRAINING_LOGS.append("❌ Please upload a dataset file")
258
+ yield "\n".join(TRAINING_LOGS)
259
+ return
260
+
261
+ dataset_file = uploaded_file.name if hasattr(uploaded_file, 'name') else uploaded_file
262
+
263
+ try:
264
+ # Load and prepare dataset
265
+ TRAINING_LOGS.append(f"📊 Loading dataset from uploaded file...")
266
+ yield "\n".join(TRAINING_LOGS)
267
+
268
+ dataset_dict, final_text_col, final_label_col = load_and_prepare_local_dataset(
269
+ dataset_file, text_column, label_column
270
+ )
271
+
272
+ TRAINING_LOGS.append(f"✅ Dataset loaded successfully!")
273
+ TRAINING_LOGS.append(f"- Train samples: {len(dataset_dict['train'])}")
274
+ TRAINING_LOGS.append(f"- Validation samples: {len(dataset_dict['validation'])}")
275
+ yield "\n".join(TRAINING_LOGS)
276
+
277
+ # Load model and tokenizer
278
+ TRAINING_LOGS.append("🤖 Loading BERT model and tokenizer...")
279
+ yield "\n".join(TRAINING_LOGS)
280
+
281
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
282
+ model = BertForSequenceClassification.from_pretrained(
283
+ "bert-base-uncased",
284
+ num_labels=len(CATEGORIES)
285
+ )
286
+
287
+ TRAINING_LOGS.append("✅ Model and tokenizer loaded")
288
+ yield "\n".join(TRAINING_LOGS)
289
+
290
+ # Tokenize datasets
291
+ TRAINING_LOGS.append("🔤 Tokenizing datasets...")
292
+ yield "\n".join(TRAINING_LOGS)
293
+
294
+ def tokenize_batch(examples):
295
+ return tokenize_function(examples, tokenizer, final_text_col, 512)
296
+
297
+ # Get columns to remove (keep only label column and tokenized features)
298
+ columns_to_remove = [col for col in dataset_dict['train'].column_names if col != final_label_col]
299
+
300
+ tokenized_datasets = dataset_dict.map(
301
+ tokenize_batch,
302
+ batched=True,
303
+ remove_columns=columns_to_remove
304
+ )
305
+
306
+ # Rename label column to 'labels' (required by Trainer)
307
+ tokenized_datasets = tokenized_datasets.rename_column(final_label_col, 'labels')
308
+
309
+ TRAINING_LOGS.append("✅ Tokenization completed")
310
+ yield "\n".join(TRAINING_LOGS)
311
+
312
+ # Set up training
313
+ output_dir = Path(MODEL_PATH)
314
+ output_dir.mkdir(parents=True, exist_ok=True)
315
+
316
+ # Calculate steps
317
+ total_steps = len(tokenized_datasets['train']) // batch_size * num_epochs
318
+ eval_steps = max(10, min(100, total_steps // 4))
319
+ save_steps = max(20, min(500, total_steps // 2))
320
+ logging_steps = max(5, min(50, total_steps // 10))
321
+ warmup_steps = min(500, total_steps // 10)
322
+
323
+ TRAINING_LOGS.append(f"📈 Training configuration:")
324
+ TRAINING_LOGS.append(f"- Total steps: {total_steps}")
325
+ TRAINING_LOGS.append(f"- Eval steps: {eval_steps}")
326
+ TRAINING_LOGS.append(f"- Warmup steps: {warmup_steps}")
327
+ yield "\n".join(TRAINING_LOGS)
328
+
329
+ # Training arguments
330
+ training_args = TrainingArguments(
331
+ output_dir=str(output_dir),
332
+ num_train_epochs=num_epochs,
333
+ per_device_train_batch_size=batch_size,
334
+ per_device_eval_batch_size=batch_size,
335
+ warmup_steps=warmup_steps,
336
+ weight_decay=0.01,
337
+ learning_rate=learning_rate,
338
+ logging_dir=str(output_dir / "logs"),
339
+ logging_steps=logging_steps,
340
+ eval_strategy="steps",
341
+ eval_steps=eval_steps,
342
+ save_steps=save_steps,
343
+ save_total_limit=2,
344
+ load_best_model_at_end=True,
345
+ metric_for_best_model="eval_accuracy",
346
+ greater_is_better=True,
347
+ push_to_hub=push_to_hub,
348
+ hub_model_id=hub_model_id if push_to_hub else None,
349
+ report_to=None,
350
+ dataloader_num_workers=0,
351
+ fp16=torch.cuda.is_available(),
352
+ seed=42,
353
+ remove_unused_columns=False,
354
+ )
355
+
356
+ # Data collator
357
+ data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
358
+
359
+ # Create trainer
360
+ trainer = Trainer(
361
+ model=model,
362
+ args=training_args,
363
+ train_dataset=tokenized_datasets['train'],
364
+ eval_dataset=tokenized_datasets['validation'],
365
+ tokenizer=tokenizer,
366
+ data_collator=data_collator,
367
+ compute_metrics=compute_metrics,
368
+ callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
369
+ )
370
+
371
+ TRAINING_LOGS.append("🚀 Starting training...")
372
+ yield "\n".join(TRAINING_LOGS)
373
+
374
+ # Custom training loop with progress updates
375
+ class ProgressCallback:
376
+ def __init__(self, logs_list):
377
+ self.logs = logs_list
378
+ self.step_count = 0
379
+
380
+ def on_step_end(self, args, state, control, model=None, **kwargs):
381
+ self.step_count += 1
382
+ if self.step_count % logging_steps == 0:
383
+ self.logs.append(f"Step {self.step_count}/{total_steps}")
384
+
385
+ def on_epoch_end(self, args, state, control, model=None, **kwargs):
386
+ epoch = int(state.epoch)
387
+ self.logs.append(f"✅ Epoch {epoch} completed")
388
+
389
+ def on_evaluate(self, args, state, control, model=None, logs=None, **kwargs):
390
+ if logs:
391
+ acc = logs.get('eval_accuracy', 0)
392
+ loss = logs.get('eval_loss', 0)
393
+ self.logs.append(f"📊 Eval - Accuracy: {acc:.4f}, Loss: {loss:.4f}")
394
+
395
+ progress_callback = ProgressCallback(TRAINING_LOGS)
396
+ trainer.add_callback(progress_callback)
397
+
398
+ # Train the model
399
+ try:
400
+ trainer.train()
401
+ TRAINING_LOGS.append("✅ Training completed successfully!")
402
+ yield "\n".join(TRAINING_LOGS)
403
+ except Exception as e:
404
+ TRAINING_LOGS.append(f"❌ Training failed: {str(e)}")
405
+ yield "\n".join(TRAINING_LOGS)
406
+ return
407
+
408
+ # Save the model
409
+ TRAINING_LOGS.append("💾 Saving model...")
410
+ yield "\n".join(TRAINING_LOGS)
411
+
412
+ trainer.save_model()
413
+ tokenizer.save_pretrained(output_dir)
414
+
415
+ # Update global model and tokenizer
416
+ CURRENT_MODEL = model
417
+ CURRENT_TOKENIZER = tokenizer
418
+
419
+ TRAINING_LOGS.append("✅ Model saved successfully!")
420
+ yield "\n".join(TRAINING_LOGS)
421
+
422
+ # Final evaluation
423
+ TRAINING_LOGS.append("📊 Running final evaluation...")
424
+ yield "\n".join(TRAINING_LOGS)
425
+
426
+ try:
427
+ eval_results = trainer.evaluate()
428
+ TRAINING_LOGS.append("📊 Final Results:")
429
+ for key, value in eval_results.items():
430
+ if isinstance(value, float):
431
+ TRAINING_LOGS.append(f" {key}: {value:.4f}")
432
+ else:
433
+ TRAINING_LOGS.append(f" {key}: {value}")
434
+
435
+ # Save results
436
+ with open(output_dir / "eval_results.json", "w") as f:
437
+ json.dump(eval_results, f, indent=2)
438
+
439
+ except Exception as e:
440
+ TRAINING_LOGS.append(f"⚠️ Evaluation error: {str(e)}")
441
+
442
+ yield "\n".join(TRAINING_LOGS)
443
+
444
+ # Push to hub if requested
445
+ if push_to_hub and hub_model_id:
446
+ TRAINING_LOGS.append(f"🤗 Pushing to Hugging Face Hub: {hub_model_id}")
447
+ yield "\n".join(TRAINING_LOGS)
448
+
449
+ try:
450
+ trainer.push_to_hub()
451
+ TRAINING_LOGS.append(f"✅ Successfully pushed to {hub_model_id}")
452
+ except Exception as e:
453
+ TRAINING_LOGS.append(f"❌ Push to Hub failed: {str(e)}")
454
+
455
+ yield "\n".join(TRAINING_LOGS)
456
+
457
+ TRAINING_LOGS.append("\n✨ Training completed! Your model is ready to use.")
458
+ yield "\n".join(TRAINING_LOGS)
459
+
460
+ except Exception as e:
461
+ TRAINING_LOGS.append(f"❌ Error during training: {str(e)}")
462
+ yield "\n".join(TRAINING_LOGS)
463
+
464
  def predict_text(text, model_path):
465
  """Make a prediction on a single text input"""
466
  global CURRENT_MODEL, CURRENT_TOKENIZER
 
472
  return load_result
473
 
474
  try:
475
+ if not text.strip():
476
+ return "Please enter some text to classify."
477
+
478
+ # Check if text was truncated
479
+ original_tokens = CURRENT_TOKENIZER(text, truncation=False)
480
+ was_truncated = len(original_tokens['input_ids']) > 512
481
+
482
  # Tokenize input
483
  inputs = CURRENT_TOKENIZER(text, return_tensors="pt", truncation=True, max_length=512)
484
 
485
  # Make prediction
486
  with torch.no_grad():
487
  outputs = CURRENT_MODEL(**inputs)
488
+ predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
489
+ predicted_class_id = predictions.argmax().item()
490
+ confidence = predictions.max().item()
491
 
492
+ # Get predicted category
493
+ predicted_category = idx_to_category[predicted_class_id]
494
 
495
+ # Format result
 
 
496
  truncation_warning = "\n\n⚠️ Note: This complaint was truncated to fit BERT's 512 token limit." if was_truncated else ""
497
 
498
+ result = []
499
+ result.append(f"**Complaint:** {text}")
500
+ result.append(f"\n**Predicted Category:** {predicted_category}")
501
+ result.append(f"**Confidence:** {confidence:.4f}")
502
+ result.append("\n**All Class Probabilities:**")
503
+
504
+ for i, category in enumerate(CATEGORIES):
505
+ prob = predictions[0][i].item()
506
+ result.append(f"- {category}: {prob:.4f}")
507
+
508
+ result.append(truncation_warning)
509
+
510
+ return "\n".join(result)
511
+
512
  except Exception as e:
513
+ return f"❌ Prediction error: {str(e)}"
514
 
515
  def predict_csv(csv_file, model_path):
516
  """Make predictions on a CSV file with complaints"""
 
533
  return "❌ CSV file must have a 'complaint' column"
534
 
535
  results = []
536
+ predictions_list = []
537
  truncated_count = 0
538
 
539
  for i, row in enumerate(df.iterrows()):
 
549
  inputs = CURRENT_TOKENIZER(complaint, return_tensors="pt", truncation=True, max_length=512)
550
  with torch.no_grad():
551
  outputs = CURRENT_MODEL(**inputs)
552
+ predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
553
+ predicted_idx = predictions.argmax().item()
554
+ confidence = predictions.max().item()
555
 
556
  predicted_category = idx_to_category[predicted_idx]
557
+ predictions_list.append({
558
+ 'complaint': complaint,
559
+ 'predicted_category': predicted_category,
560
+ 'confidence': confidence,
561
+ 'truncated': was_truncated
562
+ })
563
 
564
  truncation_mark = " ⚠️" if was_truncated else ""
565
  preview = complaint if len(complaint) <= 50 else complaint[:47] + "..."
566
+ results.append(f"Complaint {i+1}{truncation_mark}: {preview}")
567
+ results.append(f"Predicted: {predicted_category} (confidence: {confidence:.3f})\n")
568
 
569
  if i >= 19:
570
  results.append(f"... and {len(df) - 20} more (showing first 20 out of {len(df)} complaints)")
 
573
  if truncated_count > 0:
574
  results.append(f"\n⚠️ {truncated_count} complaints were truncated to fit BERT's 512 token limit.")
575
 
576
+ # Save full results to a CSV file
577
+ results_df = pd.DataFrame(predictions_list)
578
+ results_file = "prediction_results.csv"
579
+ results_df.to_csv(results_file, index=False)
580
+ results.append(f"\n💾 Full results saved to {results_file}")
581
+
582
  return "\n".join(results)
583
+
584
  except Exception as e:
585
  return f"❌ CSV processing failed: {str(e)}"
586
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
587
  def push_to_hub_after_training(model_path, username, model_name, token):
588
  """Push a trained model to Hugging Face Hub"""
589
  try:
 
616
  except Exception as e:
617
  return f"❌ Error: {str(e)}"
618
 
619
+ def count_tokens(text):
620
+ """Count tokens in input text"""
621
+ if not text or CURRENT_TOKENIZER is None:
622
+ return "Enter text to see token count"
623
+ tokens = CURRENT_TOKENIZER(text, truncation=False)
624
+ count = len(tokens['input_ids'])
625
+ if count > 512:
626
+ return f"⚠️ **Token count: {count}/512** - Text will be truncated for BERT"
627
+ else:
628
+ return f"Token count: {count}/512"
629
+
630
+ def get_available_datasets():
631
+ """Get list of available CSV files in the current directory"""
632
+ available_files = []
633
+ for file in os.listdir("."):
634
+ if file.endswith(".csv"):
635
+ try:
636
+ df = pd.read_csv(file)
637
+ available_files.append(f"{file} ({len(df)} rows)")
638
+ except:
639
+ available_files.append(f"{file} (Error reading)")
640
 
641
+ if not available_files:
642
+ available_files = ["No CSV files found in current directory"]
643
+
644
+ return available_files
645
+ def display_available_datasets():
646
+ datasets = get_available_datasets()
647
+ if datasets:
648
+ return "**Available CSV files:**\n\n" + "\n".join([f"- {file}" for file in datasets])
649
+ else:
650
+ return "No CSV files found in the current directory."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
651
 
652
+ # Initialize the display
653
+ refresh_datasets_btn.click(
654
+ display_available_datasets,
655
+ outputs=available_datasets
656
+ )
657
 
658
+ # Show datasets on load
659
+ app.load(display_available_datasets, outputs=available_datasets)
660
 
661
+ gr.Markdown("### Dataset Format Requirements")
662
+ gr.Markdown("""
663
+ **For training, your CSV file should have:**
664
+ - A text column containing the complaint text (default name: 'complaint')
665
+ - A label column containing categories (default name: 'category')
 
666
 
667
+ **Supported label formats:**
668
+ - Text labels: 'Online-Safety', 'BroadBand', 'TV-Radio'
669
+ - Numeric labels: 0, 1, 2 (corresponding to the categories above)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
670
 
671
+ **Example CSV structure:**
672
+ ```
673
+ complaint,category
674
+ "My internet is slow",BroadBand
675
+ "Blocked website access",Online-Safety
676
+ "Poor TV signal",TV-Radio
677
+ ```
678
+ """)
 
 
 
 
 
 
 
679
 
680
+ gr.Markdown("### Model Categories")
681
+ categories_info = f"""
682
+ **The model classifies complaints into these categories:**
 
 
683
 
684
+ | Index | Category | Description |
685
+ |-------|----------|-------------|
686
+ | 0 | Online-Safety | Internet safety, content filtering, cybersecurity issues |
687
+ | 1 | BroadBand | Internet connectivity, speed, network problems |
688
+ | 2 | TV-Radio | Television and radio broadcasting, signal quality issues |
689
+ """
690
+ gr.Markdown(categories_info)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
691
 
692
  # Launch the app
693
  if __name__ == "__main__":
694
  # Initialize tokenizer on startup
695
  if CURRENT_TOKENIZER is None:
696
+ try:
697
+ CURRENT_TOKENIZER = AutoTokenizer.from_pretrained("bert-base-uncased")
698
+ print("✅ Tokenizer initialized successfully")
699
+ except Exception as e:
700
+ print(f"⚠️ Warning: Could not initialize tokenizer: {e}")
701
+
702
+ print("🚀 Launching BERT Complaint Classifier...")
703
+ print("📍 Available at: http://localhost:7860")
704
+
705
+ app.launch(
706
+ server_name="0.0.0.0",
707
+ server_port=7860,
708
+ share=False,
709
+ show_error=True
710
+ )
711
+