msmaje committed on
Commit
f3b6548
·
verified ·
1 Parent(s): 2b7e143

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +164 -93
app.py CHANGED
@@ -12,7 +12,7 @@ from sklearn.model_selection import train_test_split
12
 
13
  from huggingface_hub import login, HfApi
14
  from transformers import (
15
- AutoTokenizer,
16
  BertForSequenceClassification,
17
  TrainingArguments,
18
  Trainer,
@@ -35,16 +35,16 @@ TRAINING_LOGS = []
35
  CURRENT_MODEL = None
36
  CURRENT_TOKENIZER = None
37
 
38
- # --- Application Logic Functions (No change needed here, they are correctly indented) ---
39
-
40
  def load_and_prepare_local_dataset(file_path, text_column, label_column, test_size=0.2):
41
  """Load and prepare local CSV dataset for training"""
42
  try:
43
  if not os.path.exists(file_path):
44
  raise FileNotFoundError(f"Dataset file not found: {file_path}")
45
 
 
46
  df = pd.read_csv(file_path)
47
 
 
48
  if text_column not in df.columns:
49
  available_cols = list(df.columns)
50
  raise ValueError(f"Text column '{text_column}' not found. Available columns: {available_cols}")
@@ -53,19 +53,24 @@ def load_and_prepare_local_dataset(file_path, text_column, label_column, test_si
53
  available_cols = list(df.columns)
54
  raise ValueError(f"Label column '{label_column}' not found. Available columns: {available_cols}")
55
 
 
56
  df = df.dropna(subset=[text_column, label_column])
57
  df[text_column] = df[text_column].astype(str)
58
 
 
59
  if df[label_column].dtype == 'object':
 
60
  unique_labels = df[label_column].unique()
61
  if len(unique_labels) > len(CATEGORIES):
62
  raise ValueError(f"Too many unique labels ({len(unique_labels)}). Expected max {len(CATEGORIES)}")
63
 
 
64
  label_mapping = {}
65
  for label in unique_labels:
66
  if label in category_to_idx:
67
  label_mapping[label] = category_to_idx[label]
68
  else:
 
69
  available_indices = set(range(len(CATEGORIES))) - set(label_mapping.values())
70
  if available_indices:
71
  label_mapping[label] = min(available_indices)
@@ -74,18 +79,22 @@ def load_and_prepare_local_dataset(file_path, text_column, label_column, test_si
74
 
75
  df['label_idx'] = df[label_column].map(label_mapping)
76
  else:
 
77
  df['label_idx'] = df[label_column].astype(int)
78
 
 
79
  if df['label_idx'].min() < 0 or df['label_idx'].max() >= len(CATEGORIES):
80
  raise ValueError(f"Label indices must be between 0 and {len(CATEGORIES)-1}")
81
 
 
82
  train_df, val_df = train_test_split(
83
- df,
84
- test_size=test_size,
85
- random_state=42,
86
  stratify=df['label_idx']
87
  )
88
 
 
89
  train_dataset = Dataset.from_pandas(train_df[[text_column, 'label_idx']])
90
  val_dataset = Dataset.from_pandas(val_df[[text_column, 'label_idx']])
91
 
@@ -105,6 +114,7 @@ def preview_dataset(uploaded_file, text_column, label_column):
105
  if uploaded_file is None:
106
  return "Please upload a dataset file first."
107
 
 
108
  file_path = uploaded_file.name if hasattr(uploaded_file, 'name') else uploaded_file
109
 
110
  df = pd.read_csv(file_path)
@@ -152,9 +162,11 @@ def validate_hub_model_id(username, model_name):
152
  if not username or not model_name:
153
  return None, "Please provide both username and model name"
154
 
 
155
  model_name = model_name.strip().lower().replace(" ", "-")
156
  model_name = ''.join(c for c in model_name if c.isalnum() or c in ['-', '_'])
157
 
 
158
  hub_model_id = f"{username}/{model_name}"
159
 
160
  return hub_model_id, None
@@ -164,6 +176,7 @@ def load_model(model_path):
164
  global CURRENT_MODEL, CURRENT_TOKENIZER
165
 
166
  try:
 
167
  if os.path.exists(model_path):
168
  CURRENT_TOKENIZER = AutoTokenizer.from_pretrained(model_path)
169
  CURRENT_MODEL = BertForSequenceClassification.from_pretrained(
@@ -172,6 +185,7 @@ def load_model(model_path):
172
  )
173
  return f"✅ Model loaded from local path: {model_path}"
174
 
 
175
  try:
176
  CURRENT_TOKENIZER = AutoTokenizer.from_pretrained(model_path)
177
  CURRENT_MODEL = BertForSequenceClassification.from_pretrained(
@@ -180,6 +194,7 @@ def load_model(model_path):
180
  )
181
  return f"✅ Model loaded from Hugging Face Hub: {model_path}"
182
  except Exception as hub_error:
 
183
  CURRENT_TOKENIZER = AutoTokenizer.from_pretrained("bert-base-uncased")
184
  CURRENT_MODEL = BertForSequenceClassification.from_pretrained(
185
  "bert-base-uncased",
@@ -215,7 +230,7 @@ def compute_metrics(eval_pred):
215
  'recall_macro': report['macro avg']['recall']
216
  }
217
 
218
- def train_model_inline(uploaded_file, text_column, label_column, num_epochs, batch_size,
219
  learning_rate, hf_token, push_to_hub, username, model_name):
220
  """Train the model using inline training (no subprocess)"""
221
  global TRAINING_LOGS, MODEL_PATH, CURRENT_MODEL, CURRENT_TOKENIZER
@@ -227,6 +242,7 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
227
  TRAINING_LOGS.append(login_result)
228
  yield "\n".join(TRAINING_LOGS)
229
 
 
230
  if push_to_hub:
231
  hub_model_id, error = validate_hub_model_id(username, model_name)
232
  if error:
@@ -236,6 +252,7 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
236
  else:
237
  hub_model_id = None
238
 
 
239
  if uploaded_file is None:
240
  TRAINING_LOGS.append("❌ Please upload a dataset file")
241
  yield "\n".join(TRAINING_LOGS)
@@ -244,6 +261,7 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
244
  dataset_file = uploaded_file.name if hasattr(uploaded_file, 'name') else uploaded_file
245
 
246
  try:
 
247
  TRAINING_LOGS.append(f"๐Ÿ“Š Loading dataset from uploaded file...")
248
  yield "\n".join(TRAINING_LOGS)
249
 
@@ -256,6 +274,7 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
256
  TRAINING_LOGS.append(f"- Validation samples: {len(dataset_dict['validation'])}")
257
  yield "\n".join(TRAINING_LOGS)
258
 
 
259
  TRAINING_LOGS.append("๐Ÿค– Loading BERT model and tokenizer...")
260
  yield "\n".join(TRAINING_LOGS)
261
 
@@ -268,12 +287,14 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
268
  TRAINING_LOGS.append("โœ… Model and tokenizer loaded")
269
  yield "\n".join(TRAINING_LOGS)
270
 
 
271
  TRAINING_LOGS.append("๐Ÿ”ค Tokenizing datasets...")
272
  yield "\n".join(TRAINING_LOGS)
273
 
274
  def tokenize_batch(examples):
275
  return tokenize_function(examples, tokenizer, final_text_col, 512)
276
 
 
277
  columns_to_remove = [col for col in dataset_dict['train'].column_names if col != final_label_col]
278
 
279
  tokenized_datasets = dataset_dict.map(
@@ -282,14 +303,17 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
282
  remove_columns=columns_to_remove
283
  )
284
 
 
285
  tokenized_datasets = tokenized_datasets.rename_column(final_label_col, 'labels')
286
 
287
  TRAINING_LOGS.append("โœ… Tokenization completed")
288
  yield "\n".join(TRAINING_LOGS)
289
 
 
290
  output_dir = Path(MODEL_PATH)
291
  output_dir.mkdir(parents=True, exist_ok=True)
292
 
 
293
  total_steps = len(tokenized_datasets['train']) // batch_size * num_epochs
294
  eval_steps = max(10, min(100, total_steps // 4))
295
  save_steps = max(20, min(500, total_steps // 2))
@@ -302,6 +326,7 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
302
  TRAINING_LOGS.append(f"- Warmup steps: {warmup_steps}")
303
  yield "\n".join(TRAINING_LOGS)
304
 
 
305
  training_args = TrainingArguments(
306
  output_dir=str(output_dir),
307
  num_train_epochs=num_epochs,
@@ -328,8 +353,10 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
328
  remove_unused_columns=False,
329
  )
330
 
 
331
  data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
332
 
 
333
  trainer = Trainer(
334
  model=model,
335
  args=training_args,
@@ -344,6 +371,7 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
344
  TRAINING_LOGS.append("๐Ÿš€ Starting training...")
345
  yield "\n".join(TRAINING_LOGS)
346
 
 
347
  class ProgressCallback:
348
  def __init__(self, logs_list):
349
  self.logs = logs_list
@@ -367,6 +395,7 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
367
  progress_callback = ProgressCallback(TRAINING_LOGS)
368
  trainer.add_callback(progress_callback)
369
 
 
370
  try:
371
  trainer.train()
372
  TRAINING_LOGS.append("โœ… Training completed successfully!")
@@ -376,18 +405,21 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
376
  yield "\n".join(TRAINING_LOGS)
377
  return
378
 
 
379
  TRAINING_LOGS.append("๐Ÿ’พ Saving model...")
380
  yield "\n".join(TRAINING_LOGS)
381
 
382
  trainer.save_model()
383
  tokenizer.save_pretrained(output_dir)
384
 
 
385
  CURRENT_MODEL = model
386
  CURRENT_TOKENIZER = tokenizer
387
 
388
  TRAINING_LOGS.append("โœ… Model saved successfully!")
389
  yield "\n".join(TRAINING_LOGS)
390
 
 
391
  TRAINING_LOGS.append("๐Ÿ“Š Running final evaluation...")
392
  yield "\n".join(TRAINING_LOGS)
393
 
@@ -400,6 +432,7 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
400
  else:
401
  TRAINING_LOGS.append(f" {key}: {value}")
402
 
 
403
  with open(output_dir / "eval_results.json", "w") as f:
404
  json.dump(eval_results, f, indent=2)
405
 
@@ -408,6 +441,7 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
408
 
409
  yield "\n".join(TRAINING_LOGS)
410
 
 
411
  if push_to_hub and hub_model_id:
412
  TRAINING_LOGS.append(f"๐Ÿค— Pushing to Hugging Face Hub: {hub_model_id}")
413
  yield "\n".join(TRAINING_LOGS)
@@ -431,6 +465,7 @@ def predict_text(text, model_path):
431
  """Make a prediction on a single text input"""
432
  global CURRENT_MODEL, CURRENT_TOKENIZER
433
 
 
434
  if CURRENT_MODEL is None or model_path != MODEL_PATH:
435
  load_result = load_model(model_path)
436
  if load_result.startswith("โŒ"):
@@ -440,19 +475,24 @@ def predict_text(text, model_path):
440
  if not text.strip():
441
  return "Please enter some text to classify."
442
 
 
443
  original_tokens = CURRENT_TOKENIZER(text, truncation=False)
444
  was_truncated = len(original_tokens['input_ids']) > 512
445
 
 
446
  inputs = CURRENT_TOKENIZER(text, return_tensors="pt", truncation=True, max_length=512)
447
 
 
448
  with torch.no_grad():
449
  outputs = CURRENT_MODEL(**inputs)
450
  predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
451
  predicted_class_id = predictions.argmax().item()
452
  confidence = predictions.max().item()
453
 
 
454
  predicted_category = idx_to_category[predicted_class_id]
455
 
 
456
  truncation_warning = "\n\n⚠️ Note: This complaint was truncated to fit BERT's 512 token limit." if was_truncated else ""
457
 
458
  result = []
@@ -476,12 +516,14 @@ def predict_csv(csv_file, model_path):
476
  """Make predictions on a CSV file with complaints"""
477
  global CURRENT_MODEL, CURRENT_TOKENIZER
478
 
 
479
  if CURRENT_MODEL is None or model_path != MODEL_PATH:
480
  load_result = load_model(model_path)
481
  if load_result.startswith("โŒ"):
482
  return load_result, None
483
 
484
  try:
 
485
  if hasattr(csv_file, 'name'):
486
  df = pd.read_csv(csv_file.name)
487
  else:
@@ -497,11 +539,13 @@ def predict_csv(csv_file, model_path):
497
  for i, row in enumerate(df.iterrows()):
498
  complaint = str(row[1]['complaint'])
499
 
 
500
  original_tokens = CURRENT_TOKENIZER(complaint, truncation=False)
501
  was_truncated = len(original_tokens['input_ids']) > 512
502
  if was_truncated:
503
  truncated_count += 1
504
 
 
505
  inputs = CURRENT_TOKENIZER(complaint, return_tensors="pt", truncation=True, max_length=512)
506
  with torch.no_grad():
507
  outputs = CURRENT_MODEL(**inputs)
@@ -529,6 +573,7 @@ def predict_csv(csv_file, model_path):
529
  if truncated_count > 0:
530
  results.append(f"\n⚠️ {truncated_count} complaints were truncated to fit BERT's 512 token limit.")
531
 
 
532
  results_df = pd.DataFrame(predictions_list)
533
  results_file = "prediction_results.csv"
534
  results_df.to_csv(results_file, index=False)
@@ -549,6 +594,7 @@ def push_to_hub_after_training(model_path, username, model_name, token):
549
  if error:
550
  return f"❌ {error}"
551
 
 
552
  login(token)
553
  if not os.path.exists(model_path):
554
  return "โŒ No trained model found. Please train a model first."
@@ -559,6 +605,7 @@ def push_to_hub_after_training(model_path, username, model_name, token):
559
  except Exception as e:
560
  return f"โŒ Failed to load model: {str(e)}"
561
 
 
562
  try:
563
  model.push_to_hub(hub_model_id)
564
  tokenizer.push_to_hub(hub_model_id)
@@ -603,8 +650,6 @@ def display_available_datasets():
603
  else:
604
  return "No CSV files found in the current directory."
605
 
606
- # --- Gradio UI Definition (Correctly structured) ---
607
-
608
  # Initialize tokenizer on startup
609
  if CURRENT_TOKENIZER is None:
610
  try:
@@ -616,7 +661,7 @@ if CURRENT_TOKENIZER is None:
616
  print("๐Ÿš€ Launching BERT Complaint Classifier...")
617
  print("๐Ÿ“ Available at: http://localhost:7860")
618
 
619
- # The entire Gradio UI definition must be within this single block
620
  with gr.Blocks(title="BERT Complaint Classifier", theme=gr.themes.Soft()) as app:
621
  gr.Markdown("# BERT Complaint Classifier 🗣️🤖")
622
  gr.Markdown("Fine-tune a BERT model or use an existing one to classify customer complaints.")
@@ -666,98 +711,86 @@ with gr.Blocks(title="BERT Complaint Classifier", theme=gr.themes.Soft()) as app
666
  with gr.Column(variant="panel"):
667
  gr.Markdown("### Classify a Single Complaint")
668
 
669
- model_path_input = gr.Textbox(
670
- label="Model Path or Hub ID",
671
- value="bert-base-uncased",
672
- placeholder="e.g., local-model or your_username/your_model"
673
- )
674
 
675
- with gr.Row():
676
- text_input = gr.Textbox(label="Complaint Text", lines=3)
677
- token_count_output = gr.Markdown("Token count: 0/512")
678
 
679
- predict_btn = gr.Button("Classify Complaint", variant="primary")
680
- single_prediction_output = gr.Markdown("Prediction will appear here...")
 
 
 
 
681
 
682
- text_input.change(count_tokens, inputs=text_input, outputs=token_count_output)
683
-
684
- with gr.Tab("Predict from CSV"):
 
 
685
  with gr.Column(variant="panel"):
686
- gr.Markdown("### Classify Complaints from a CSV File")
687
- csv_file_input = gr.File(label="Upload CSV File (with 'complaint' column)")
688
- csv_model_path = gr.Textbox(
689
- label="Model Path or Hub ID",
690
- value="local-model",
691
- placeholder="e.g., local-model or your_username/your_model"
692
- )
693
- csv_predict_btn = gr.Button("Run Predictions on CSV", variant="primary")
694
- csv_prediction_output = gr.Markdown("Predictions will appear here...")
695
- download_link = gr.File(label="Download Full Predictions", interactive=False)
696
 
697
- predict_btn.click(
698
- predict_text,
699
- inputs=[text_input, model_path_input],
700
- outputs=single_prediction_output
701
- )
702
- csv_predict_btn.click(
703
- predict_csv,
704
- inputs=[csv_file_input, csv_model_path],
705
- outputs=[csv_prediction_output, download_link]
706
- )
 
 
 
 
 
707
 
708
- with gr.Tab("Tools"):
709
- gr.Markdown("## ๐Ÿ”ง Tools")
710
- gr.Markdown("Utilities for managing datasets and models.")
 
 
 
 
 
 
711
 
712
- with gr.Accordion("Dataset Information"):
713
- available_datasets = gr.Markdown("No CSV files found in the current directory.")
714
- refresh_datasets_btn = gr.Button("๐Ÿ”„ Refresh Available Datasets")
715
- gr.Markdown("### Dataset Format Requirements")
716
- gr.Markdown("""
717
- **For training, your CSV file should have:**
718
- - A text column containing the complaint text (default name: 'complaint')
719
- - A label column containing categories (default name: 'category')
 
720
 
721
- **Supported label formats:**
722
- - Text labels: 'Online-Safety', 'BroadBand', 'TV-Radio'
723
- - Numeric labels: 0, 1, 2 (corresponding to the categories above)
724
 
725
- **Example CSV structure:**
726
- ```
727
- complaint,category
728
- "My internet is slow",BroadBand
729
- "Blocked website access",Online-Safety
730
- "Poor TV signal",TV-Radio
731
- ```
732
- """)
733
- gr.Markdown("### Model Categories")
734
- categories_info = f"""
735
- **The model classifies complaints into these categories:**
736
 
737
- | Index | Category | Description |
738
- |-------|----------|-------------|
739
- | 0 | Online-Safety | Internet safety, content filtering, cybersecurity issues |
740
- | 1 | BroadBand | Internet connectivity, speed, network problems |
741
- | 2 | TV-Radio | Television and radio broadcasting, signal quality issues |
742
- """
743
- gr.Markdown(categories_info)
744
-
745
- with gr.Accordion("Push Local Model to Hub"):
746
- gr.Markdown("Use this to manually push a locally trained model (`./local-model`) to the Hub.")
747
- with gr.Row():
748
- hub_username_input_push = gr.Textbox(label="Hugging Face Username")
749
- hub_model_name_input_push = gr.Textbox(label="Model Name")
750
- hub_token_input_push = gr.Textbox(label="Hugging Face Token", type="password")
751
 
752
- push_btn = gr.Button("๐Ÿš€ Push Model to Hub", variant="primary")
753
- push_output = gr.Textbox(label="Results", lines=3, interactive=False)
 
 
 
 
754
 
755
- push_btn.click(
756
- push_to_hub_after_training,
757
- inputs=[gr.Textbox(value=MODEL_PATH, visible=False), hub_username_input_push, hub_model_name_input_push, hub_token_input_push],
758
- outputs=push_output
759
- )
760
-
761
  preview_btn.click(
762
  preview_dataset,
763
  inputs=[uploaded_file, text_column_input, label_column_input],
@@ -781,13 +814,51 @@ with gr.Blocks(title="BERT Complaint Classifier", theme=gr.themes.Soft()) as app
781
  outputs=training_log_output,
782
  )
783
 
784
- refresh_datasets_btn.click(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
785
  display_available_datasets,
786
- outputs=available_datasets
787
  )
788
 
789
- app.load(display_available_datasets, outputs=available_datasets)
 
790
 
 
791
  if __name__ == "__main__":
792
  app.launch(
793
  server_name="0.0.0.0",
 
12
 
13
  from huggingface_hub import login, HfApi
14
  from transformers import (
15
+ AutoTokenizer,
16
  BertForSequenceClassification,
17
  TrainingArguments,
18
  Trainer,
 
35
  CURRENT_MODEL = None
36
  CURRENT_TOKENIZER = None
37
 
 
 
38
  def load_and_prepare_local_dataset(file_path, text_column, label_column, test_size=0.2):
39
  """Load and prepare local CSV dataset for training"""
40
  try:
41
  if not os.path.exists(file_path):
42
  raise FileNotFoundError(f"Dataset file not found: {file_path}")
43
 
44
+ # Load the CSV file
45
  df = pd.read_csv(file_path)
46
 
47
+ # Verify required columns exist
48
  if text_column not in df.columns:
49
  available_cols = list(df.columns)
50
  raise ValueError(f"Text column '{text_column}' not found. Available columns: {available_cols}")
 
53
  available_cols = list(df.columns)
54
  raise ValueError(f"Label column '{label_column}' not found. Available columns: {available_cols}")
55
 
56
+ # Clean the data
57
  df = df.dropna(subset=[text_column, label_column])
58
  df[text_column] = df[text_column].astype(str)
59
 
60
+ # Handle different label formats
61
  if df[label_column].dtype == 'object':
62
+ # If labels are text, convert to indices
63
  unique_labels = df[label_column].unique()
64
  if len(unique_labels) > len(CATEGORIES):
65
  raise ValueError(f"Too many unique labels ({len(unique_labels)}). Expected max {len(CATEGORIES)}")
66
 
67
+ # Try to map text labels to our categories
68
  label_mapping = {}
69
  for label in unique_labels:
70
  if label in category_to_idx:
71
  label_mapping[label] = category_to_idx[label]
72
  else:
73
+ # Auto-assign if not found
74
  available_indices = set(range(len(CATEGORIES))) - set(label_mapping.values())
75
  if available_indices:
76
  label_mapping[label] = min(available_indices)
 
79
 
80
  df['label_idx'] = df[label_column].map(label_mapping)
81
  else:
82
+ # If labels are already numeric
83
  df['label_idx'] = df[label_column].astype(int)
84
 
85
+ # Verify label indices are valid
86
  if df['label_idx'].min() < 0 or df['label_idx'].max() >= len(CATEGORIES):
87
  raise ValueError(f"Label indices must be between 0 and {len(CATEGORIES)-1}")
88
 
89
+ # Create train/validation split
90
  train_df, val_df = train_test_split(
91
+ df,
92
+ test_size=test_size,
93
+ random_state=42,
94
  stratify=df['label_idx']
95
  )
96
 
97
+ # Convert to Hugging Face datasets
98
  train_dataset = Dataset.from_pandas(train_df[[text_column, 'label_idx']])
99
  val_dataset = Dataset.from_pandas(val_df[[text_column, 'label_idx']])
100
 
 
114
  if uploaded_file is None:
115
  return "Please upload a dataset file first."
116
 
117
+ # Get the file path from the uploaded file
118
  file_path = uploaded_file.name if hasattr(uploaded_file, 'name') else uploaded_file
119
 
120
  df = pd.read_csv(file_path)
 
162
  if not username or not model_name:
163
  return None, "Please provide both username and model name"
164
 
165
+ # Clean up the model name
166
  model_name = model_name.strip().lower().replace(" ", "-")
167
  model_name = ''.join(c for c in model_name if c.isalnum() or c in ['-', '_'])
168
 
169
+ # Construct the full model ID
170
  hub_model_id = f"{username}/{model_name}"
171
 
172
  return hub_model_id, None
 
176
  global CURRENT_MODEL, CURRENT_TOKENIZER
177
 
178
  try:
179
+ # Try loading from local path first
180
  if os.path.exists(model_path):
181
  CURRENT_TOKENIZER = AutoTokenizer.from_pretrained(model_path)
182
  CURRENT_MODEL = BertForSequenceClassification.from_pretrained(
 
185
  )
186
  return f"✅ Model loaded from local path: {model_path}"
187
 
188
+ # If local path doesn't exist, try loading from Hub
189
  try:
190
  CURRENT_TOKENIZER = AutoTokenizer.from_pretrained(model_path)
191
  CURRENT_MODEL = BertForSequenceClassification.from_pretrained(
 
194
  )
195
  return f"✅ Model loaded from Hugging Face Hub: {model_path}"
196
  except Exception as hub_error:
197
+ # If both local and hub loading fail, fall back to base model
198
  CURRENT_TOKENIZER = AutoTokenizer.from_pretrained("bert-base-uncased")
199
  CURRENT_MODEL = BertForSequenceClassification.from_pretrained(
200
  "bert-base-uncased",
 
230
  'recall_macro': report['macro avg']['recall']
231
  }
232
 
233
+ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, batch_size,
234
  learning_rate, hf_token, push_to_hub, username, model_name):
235
  """Train the model using inline training (no subprocess)"""
236
  global TRAINING_LOGS, MODEL_PATH, CURRENT_MODEL, CURRENT_TOKENIZER
 
242
  TRAINING_LOGS.append(login_result)
243
  yield "\n".join(TRAINING_LOGS)
244
 
245
+ # Validate hub model ID if pushing to hub
246
  if push_to_hub:
247
  hub_model_id, error = validate_hub_model_id(username, model_name)
248
  if error:
 
252
  else:
253
  hub_model_id = None
254
 
255
+ # Validate uploaded file
256
  if uploaded_file is None:
257
  TRAINING_LOGS.append("โŒ Please upload a dataset file")
258
  yield "\n".join(TRAINING_LOGS)
 
261
  dataset_file = uploaded_file.name if hasattr(uploaded_file, 'name') else uploaded_file
262
 
263
  try:
264
+ # Load and prepare dataset
265
  TRAINING_LOGS.append(f"๐Ÿ“Š Loading dataset from uploaded file...")
266
  yield "\n".join(TRAINING_LOGS)
267
 
 
274
  TRAINING_LOGS.append(f"- Validation samples: {len(dataset_dict['validation'])}")
275
  yield "\n".join(TRAINING_LOGS)
276
 
277
+ # Load model and tokenizer
278
  TRAINING_LOGS.append("๐Ÿค– Loading BERT model and tokenizer...")
279
  yield "\n".join(TRAINING_LOGS)
280
 
 
287
  TRAINING_LOGS.append("โœ… Model and tokenizer loaded")
288
  yield "\n".join(TRAINING_LOGS)
289
 
290
+ # Tokenize datasets
291
  TRAINING_LOGS.append("๐Ÿ”ค Tokenizing datasets...")
292
  yield "\n".join(TRAINING_LOGS)
293
 
294
  def tokenize_batch(examples):
295
  return tokenize_function(examples, tokenizer, final_text_col, 512)
296
 
297
+ # Get columns to remove (keep only label column and tokenized features)
298
  columns_to_remove = [col for col in dataset_dict['train'].column_names if col != final_label_col]
299
 
300
  tokenized_datasets = dataset_dict.map(
 
303
  remove_columns=columns_to_remove
304
  )
305
 
306
+ # Rename label column to 'labels' (required by Trainer)
307
  tokenized_datasets = tokenized_datasets.rename_column(final_label_col, 'labels')
308
 
309
  TRAINING_LOGS.append("โœ… Tokenization completed")
310
  yield "\n".join(TRAINING_LOGS)
311
 
312
+ # Set up training
313
  output_dir = Path(MODEL_PATH)
314
  output_dir.mkdir(parents=True, exist_ok=True)
315
 
316
+ # Calculate steps
317
  total_steps = len(tokenized_datasets['train']) // batch_size * num_epochs
318
  eval_steps = max(10, min(100, total_steps // 4))
319
  save_steps = max(20, min(500, total_steps // 2))
 
326
  TRAINING_LOGS.append(f"- Warmup steps: {warmup_steps}")
327
  yield "\n".join(TRAINING_LOGS)
328
 
329
+ # Training arguments
330
  training_args = TrainingArguments(
331
  output_dir=str(output_dir),
332
  num_train_epochs=num_epochs,
 
353
  remove_unused_columns=False,
354
  )
355
 
356
+ # Data collator
357
  data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
358
 
359
+ # Create trainer
360
  trainer = Trainer(
361
  model=model,
362
  args=training_args,
 
371
  TRAINING_LOGS.append("๐Ÿš€ Starting training...")
372
  yield "\n".join(TRAINING_LOGS)
373
 
374
+ # Custom training loop with progress updates
375
  class ProgressCallback:
376
  def __init__(self, logs_list):
377
  self.logs = logs_list
 
395
  progress_callback = ProgressCallback(TRAINING_LOGS)
396
  trainer.add_callback(progress_callback)
397
 
398
+ # Train the model
399
  try:
400
  trainer.train()
401
  TRAINING_LOGS.append("โœ… Training completed successfully!")
 
405
  yield "\n".join(TRAINING_LOGS)
406
  return
407
 
408
+ # Save the model
409
  TRAINING_LOGS.append("๐Ÿ’พ Saving model...")
410
  yield "\n".join(TRAINING_LOGS)
411
 
412
  trainer.save_model()
413
  tokenizer.save_pretrained(output_dir)
414
 
415
+ # Update global model and tokenizer
416
  CURRENT_MODEL = model
417
  CURRENT_TOKENIZER = tokenizer
418
 
419
  TRAINING_LOGS.append("โœ… Model saved successfully!")
420
  yield "\n".join(TRAINING_LOGS)
421
 
422
+ # Final evaluation
423
  TRAINING_LOGS.append("๐Ÿ“Š Running final evaluation...")
424
  yield "\n".join(TRAINING_LOGS)
425
 
 
432
  else:
433
  TRAINING_LOGS.append(f" {key}: {value}")
434
 
435
+ # Save results
436
  with open(output_dir / "eval_results.json", "w") as f:
437
  json.dump(eval_results, f, indent=2)
438
 
 
441
 
442
  yield "\n".join(TRAINING_LOGS)
443
 
444
+ # Push to hub if requested
445
  if push_to_hub and hub_model_id:
446
  TRAINING_LOGS.append(f"๐Ÿค— Pushing to Hugging Face Hub: {hub_model_id}")
447
  yield "\n".join(TRAINING_LOGS)
 
465
  """Make a prediction on a single text input"""
466
  global CURRENT_MODEL, CURRENT_TOKENIZER
467
 
468
+ # Load the model if it's not loaded or a different one is requested
469
  if CURRENT_MODEL is None or model_path != MODEL_PATH:
470
  load_result = load_model(model_path)
471
  if load_result.startswith("โŒ"):
 
475
  if not text.strip():
476
  return "Please enter some text to classify."
477
 
478
+ # Check if text was truncated
479
  original_tokens = CURRENT_TOKENIZER(text, truncation=False)
480
  was_truncated = len(original_tokens['input_ids']) > 512
481
 
482
+ # Tokenize input
483
  inputs = CURRENT_TOKENIZER(text, return_tensors="pt", truncation=True, max_length=512)
484
 
485
+ # Make prediction
486
  with torch.no_grad():
487
  outputs = CURRENT_MODEL(**inputs)
488
  predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
489
  predicted_class_id = predictions.argmax().item()
490
  confidence = predictions.max().item()
491
 
492
+ # Get predicted category
493
  predicted_category = idx_to_category[predicted_class_id]
494
 
495
+ # Format result
496
  truncation_warning = "\n\nโš ๏ธ Note: This complaint was truncated to fit BERT's 512 token limit." if was_truncated else ""
497
 
498
  result = []
 
516
  """Make predictions on a CSV file with complaints"""
517
  global CURRENT_MODEL, CURRENT_TOKENIZER
518
 
519
+ # Load the model if needed
520
  if CURRENT_MODEL is None or model_path != MODEL_PATH:
521
  load_result = load_model(model_path)
522
  if load_result.startswith("โŒ"):
523
  return load_result, None
524
 
525
  try:
526
+ # Read the CSV file
527
  if hasattr(csv_file, 'name'):
528
  df = pd.read_csv(csv_file.name)
529
  else:
 
539
  for i, row in enumerate(df.iterrows()):
540
  complaint = str(row[1]['complaint'])
541
 
542
+ # Check for truncation
543
  original_tokens = CURRENT_TOKENIZER(complaint, truncation=False)
544
  was_truncated = len(original_tokens['input_ids']) > 512
545
  if was_truncated:
546
  truncated_count += 1
547
 
548
+ # Predict
549
  inputs = CURRENT_TOKENIZER(complaint, return_tensors="pt", truncation=True, max_length=512)
550
  with torch.no_grad():
551
  outputs = CURRENT_MODEL(**inputs)
 
573
  if truncated_count > 0:
574
  results.append(f"\nโš ๏ธ {truncated_count} complaints were truncated to fit BERT's 512 token limit.")
575
 
576
+ # Save full results to a CSV file
577
  results_df = pd.DataFrame(predictions_list)
578
  results_file = "prediction_results.csv"
579
  results_df.to_csv(results_file, index=False)
 
594
  if error:
595
  return f"โŒ {error}"
596
 
597
+ # Login and load model
598
  login(token)
599
  if not os.path.exists(model_path):
600
  return "โŒ No trained model found. Please train a model first."
 
605
  except Exception as e:
606
  return f"โŒ Failed to load model: {str(e)}"
607
 
608
+ # Push to Hub
609
  try:
610
  model.push_to_hub(hub_model_id)
611
  tokenizer.push_to_hub(hub_model_id)
 
650
  else:
651
  return "No CSV files found in the current directory."
652
 
 
 
653
  # Initialize tokenizer on startup
654
  if CURRENT_TOKENIZER is None:
655
  try:
 
661
  print("๐Ÿš€ Launching BERT Complaint Classifier...")
662
  print("๐Ÿ“ Available at: http://localhost:7860")
663
 
664
+ # The entire Gradio UI definition must be within a single block
665
  with gr.Blocks(title="BERT Complaint Classifier", theme=gr.themes.Soft()) as app:
666
  gr.Markdown("# BERT Complaint Classifier 🗣️🤖")
667
  gr.Markdown("Fine-tune a BERT model or use an existing one to classify customer complaints.")
 
711
  with gr.Column(variant="panel"):
712
  gr.Markdown("### Classify a Single Complaint")
713
 
714
+ model_path_input = gr.Textbox(label="Model Path", value=MODEL_PATH, placeholder="Enter model path or HuggingFace model ID")
715
+ load_model_btn = gr.Button("Load Model")
716
+ model_status = gr.Textbox(label="Model Status", interactive=False)
 
 
717
 
718
+ gr.Markdown("---")
 
 
719
 
720
+ text_input = gr.Textbox(
721
+ label="Enter complaint text",
722
+ lines=3,
723
+ placeholder="Type your complaint here..."
724
+ )
725
+ token_counter = gr.Textbox(label="Token Count", interactive=False, value="Enter text to see token count")
726
 
727
+ predict_btn = gr.Button("🔮 Predict Category", variant="primary")
728
+
729
+ prediction_output = gr.Markdown("Prediction results will appear here")
730
+
731
+ with gr.Tab("Predict CSV File"):
732
  with gr.Column(variant="panel"):
733
+ gr.Markdown("### Classify Multiple Complaints from CSV")
734
+ gr.Markdown("Upload a CSV file with a 'complaint' column to classify multiple complaints at once.")
 
 
 
 
 
 
 
 
735
 
736
+ csv_model_path = gr.Textbox(label="Model Path", value=MODEL_PATH, placeholder="Enter model path or HuggingFace model ID")
737
+ csv_load_btn = gr.Button("Load Model")
738
+ csv_model_status = gr.Textbox(label="Model Status", interactive=False)
739
+
740
+ gr.Markdown("---")
741
+
742
+ csv_file_input = gr.File(label="Upload CSV File", type="filepath", file_types=["csv"])
743
+ csv_predict_btn = gr.Button("๐Ÿ”ฎ Predict All", variant="primary")
744
+
745
+ csv_prediction_output = gr.Markdown("CSV prediction results will appear here")
746
+ csv_download = gr.File(label="Download Results", interactive=False)
747
+
748
+ with gr.Tab("Push to Hub"):
749
+ gr.Markdown("## 🤗 Push Trained Model to Hugging Face Hub")
750
+ gr.Markdown("Upload your locally trained model to the Hugging Face Hub for sharing.")
751
 
752
+ with gr.Column(variant="panel"):
753
+ hub_model_path = gr.Textbox(label="Local Model Path", value=MODEL_PATH)
754
+ hub_username = gr.Textbox(label="Hugging Face Username")
755
+ hub_model_name = gr.Textbox(label="Model Name", value="bert-complaint-classifier")
756
+ hub_token = gr.Textbox(label="Hugging Face Token", type="password")
757
+
758
+ push_hub_btn = gr.Button("๐Ÿš€ Push to Hub", variant="primary")
759
+
760
+ push_hub_output = gr.Markdown("Push results will appear here")
761
 
762
+ with gr.Tab("Dataset Info"):
763
+ gr.Markdown("## ๐Ÿ“Š Dataset Information")
764
+ gr.Markdown("View information about available datasets and model categories.")
765
+
766
+ with gr.Column(variant="panel"):
767
+ gr.Markdown("### ๐ŸŽฏ Model Categories")
768
+ categories_info = gr.Markdown(f"**Available Categories:**\n\n" + "\n".join([f"- **{cat}** (index: {idx})" for idx, cat in idx_to_category.items()]))
769
+
770
+ gr.Markdown("---")
771
 
772
+ gr.Markdown("### ๐Ÿ“ Available Datasets")
773
+ datasets_btn = gr.Button("๐Ÿ” Scan for CSV Files")
774
+ datasets_info = gr.Markdown("Click 'Scan for CSV Files' to see available datasets")
775
 
776
+ gr.Markdown("---")
 
 
 
 
 
 
 
 
 
 
777
 
778
+ gr.Markdown("### ๐Ÿ’ก Tips")
779
+ gr.Markdown("""
780
+ **Dataset Format:**
781
+ - CSV file with at least two columns
782
+ - One column for text (complaints)
783
+ - One column for labels/categories
784
+ - Labels can be text (will be auto-mapped) or numeric indices (0, 1, 2)
 
 
 
 
 
 
 
785
 
786
+ **Training Tips:**
787
+ - Start with 3 epochs and adjust based on results
788
+ - Use batch size 8-16 for most datasets
789
+ - Learning rate 2e-5 works well for BERT fine-tuning
790
+ - Enable early stopping to prevent overfitting
791
+ """)
792
 
793
+ # Connect functions to UI components
 
 
 
 
 
794
  preview_btn.click(
795
  preview_dataset,
796
  inputs=[uploaded_file, text_column_input, label_column_input],
 
814
  outputs=training_log_output,
815
  )
816
 
817
+ load_model_btn.click(
818
+ load_model,
819
+ inputs=model_path_input,
820
+ outputs=model_status
821
+ )
822
+
823
+ predict_btn.click(
824
+ predict_text,
825
+ inputs=[text_input, model_path_input],
826
+ outputs=prediction_output
827
+ )
828
+
829
+ text_input.change(
830
+ count_tokens,
831
+ inputs=text_input,
832
+ outputs=token_counter
833
+ )
834
+
835
+ csv_load_btn.click(
836
+ load_model,
837
+ inputs=csv_model_path,
838
+ outputs=csv_model_status
839
+ )
840
+
841
+ csv_predict_btn.click(
842
+ predict_csv,
843
+ inputs=[csv_file_input, csv_model_path],
844
+ outputs=[csv_prediction_output, csv_download]
845
+ )
846
+
847
+ push_hub_btn.click(
848
+ push_to_hub_after_training,
849
+ inputs=[hub_model_path, hub_username, hub_model_name, hub_token],
850
+ outputs=push_hub_output
851
+ )
852
+
853
+ datasets_btn.click(
854
  display_available_datasets,
855
+ outputs=datasets_info
856
  )
857
 
858
+ # Run a check for available datasets on app load
859
+ app.load(display_available_datasets, outputs=datasets_info)
860
 
861
+ # Launch the Gradio app
862
  if __name__ == "__main__":
863
  app.launch(
864
  server_name="0.0.0.0",