# NOTE: this file was scraped from a Hugging Face Space page (Space status:
# "Sleeping"; file size: 35,703 bytes). The page's per-line commit annotations
# and line-number gutter were removed because they are not Python.
import gradio as gr
import torch
import pandas as pd
import os
import json
import logging
import numpy as np
from datetime import datetime
from pathlib import Path
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
from huggingface_hub import login
from transformers import (
AutoTokenizer,
BertForSequenceClassification,
TrainingArguments,
Trainer,
DataCollatorWithPadding,
EarlyStoppingCallback,
TrainerCallback
)
from datasets import Dataset, DatasetDict
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global variables
# Directory where fine-tuned models are saved and, by default, loaded from.
MODEL_PATH = "local-model"
# Ordered class names; the position of each name is its label index.
CATEGORIES = ['Online-Safety', 'BroadBand', 'TV-Radio']
# Both lookup tables are derived from CATEGORIES so the three can never drift apart.
idx_to_category = dict(enumerate(CATEGORIES))
category_to_idx = {category: idx for idx, category in idx_to_category.items()}
# Hugging Face token from the most recent login attempt (set by login_to_hf).
TOKEN = None
# Log lines accumulated by the most recent training run.
TRAINING_LOGS = []
# Model/tokenizer used for inference; set by load_model and after training.
CURRENT_MODEL = None
CURRENT_TOKENIZER = None
def _build_text_label_mapping(unique_labels):
    """Return a ``{label: index}`` mapping for text labels.

    Labels that match a name in ``CATEGORIES`` keep that category's index.
    Every other label gets the lowest index not already taken. Known labels
    are placed first, so an unknown label can never take the index of a known
    one (that would silently merge two classes).

    Raises:
        ValueError: if there are more distinct labels than categories, or a
            label cannot be given a free index.
    """
    if len(unique_labels) > len(CATEGORIES):
        raise ValueError(
            f"Too many unique labels ({len(unique_labels)}). Expected max {len(CATEGORIES)}"
        )
    label_mapping = {
        label: category_to_idx[label] for label in unique_labels if label in category_to_idx
    }
    for label in unique_labels:
        if label in label_mapping:
            continue
        available_indices = set(range(len(CATEGORIES))) - set(label_mapping.values())
        if not available_indices:
            raise ValueError(f"Cannot map label '{label}' to available categories")
        label_mapping[label] = min(available_indices)
    return label_mapping


def load_and_prepare_local_dataset(file_path, text_column, label_column, test_size=0.2):
    """Load a local CSV file and split it into train/validation datasets.

    Args:
        file_path: Path to the CSV file.
        text_column: Name of the column holding the input text.
        label_column: Name of the column holding the labels. Text (object
            dtype) labels are mapped onto ``CATEGORIES``. Numeric labels are
            cast to int and must lie in ``0 .. len(CATEGORIES) - 1``.
        test_size: Fraction of rows held out for the validation split.

    Returns:
        ``(dataset_dict, text_column, 'label_idx')``. ``dataset_dict`` is a
        ``DatasetDict`` with ``train`` and ``validation`` splits, each holding
        only the text column and ``label_idx``.

    Raises:
        Exception: wrapping any failure. The original error is chained as
            ``__cause__``.
    """
    try:
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Dataset file not found: {file_path}")

        df = pd.read_csv(file_path)

        for role, column in (("Text", text_column), ("Label", label_column)):
            if column not in df.columns:
                raise ValueError(
                    f"{role} column '{column}' not found. Available columns: {list(df.columns)}"
                )

        # Drop rows missing either field and force the text to str.
        df = df.dropna(subset=[text_column, label_column])
        df[text_column] = df[text_column].astype(str)
        if df.empty:
            raise ValueError("Dataset contains no rows with both text and label")

        if df[label_column].dtype == 'object':
            label_mapping = _build_text_label_mapping(df[label_column].unique())
            df['label_idx'] = df[label_column].map(label_mapping)
        else:
            df['label_idx'] = df[label_column].astype(int)

        if df['label_idx'].min() < 0 or df['label_idx'].max() >= len(CATEGORIES):
            raise ValueError(f"Label indices must be between 0 and {len(CATEGORIES)-1}")

        try:
            train_df, val_df = train_test_split(
                df, test_size=test_size, random_state=42, stratify=df['label_idx']
            )
        except ValueError:
            # Stratification is impossible for very small or very imbalanced
            # data (e.g. a class with one row). Fall back to a random split.
            train_df, val_df = train_test_split(df, test_size=test_size, random_state=42)

        # preserve_index=False keeps the shuffled pandas index from becoming
        # an extra "__index_level_0__" column.
        columns = [text_column, 'label_idx']
        dataset_dict = DatasetDict({
            'train': Dataset.from_pandas(train_df[columns], preserve_index=False),
            'validation': Dataset.from_pandas(val_df[columns], preserve_index=False),
        })
        return dataset_dict, text_column, 'label_idx'
    except Exception as e:
        raise Exception(f"Error loading dataset: {str(e)}") from e
def preview_dataset(uploaded_file, text_column, label_column):
    """Return a Markdown summary of an uploaded CSV dataset.

    The summary reports the row count and column names. It then checks that
    ``text_column`` and ``label_column`` exist, showing a sample text and the
    label distribution, and stops at the first missing column. Never raises:
    errors come back as a message string.
    """
    try:
        if uploaded_file is None:
            return "Please upload a dataset file first."

        # Accept either a path string or a file-like object exposing .name.
        file_path = getattr(uploaded_file, 'name', uploaded_file)

        if not os.path.exists(file_path):
            return "❌ File not found. Please try uploading again."
        if not file_path.lower().endswith('.csv'):
            return "❌ Please upload a CSV file (.csv extension required)."

        df = pd.read_csv(file_path)
        total_rows = len(df)
        lines = [
            f"📊 **Dataset Preview: {os.path.basename(file_path)}**",
            f"- **Total rows:** {total_rows}",
            f"- **Columns:** {list(df.columns)}",
            "",
        ]

        if text_column not in df.columns:
            lines.append(f"❌ **Text column '{text_column}' not found**")
            return "\n".join(lines)
        lines.append(f"✅ **Text column '{text_column}' found**")
        if total_rows:
            sample = str(df[text_column].iloc[0])
            ellipsis = "..." if len(sample) > 100 else ""
            lines.append(f"- Sample text: {sample[:100]}{ellipsis}")

        if label_column not in df.columns:
            lines.append(f"❌ **Label column '{label_column}' not found**")
            return "\n".join(lines)
        lines.append(f"✅ **Label column '{label_column}' found**")
        lines.append("- **Label distribution:**")
        # value_counts() is empty for an empty frame, so no division by zero.
        for label, count in df[label_column].value_counts().items():
            lines.append(f"  - {label}: {count} ({count / total_rows * 100:.1f}%)")
        return "\n".join(lines)
    except Exception as e:
        return f"❌ Error previewing dataset: {str(e)}"
def login_to_hf(token):
    """Log in to the Hugging Face Hub with ``token`` and return a status message.

    The token is stored in the module-level ``TOKEN`` before the attempt, even
    if the login then fails. Never raises.
    """
    global TOKEN
    TOKEN = token
    try:
        login(token=token)
    except Exception as e:
        return f"❌ Login failed: {str(e)}"
    return "✅ Successfully logged in to Hugging Face"
def validate_hub_model_id(username, model_name):
    """Build a ``"username/model-name"`` Hub repo id from user input.

    Both parts are stripped of surrounding whitespace. The model name is then
    lower-cased, spaces become hyphens, and every character other than ASCII
    letters, digits, ``-`` and ``_`` is dropped.

    Returns:
        ``(hub_model_id, None)`` on success, or ``(None, error_message)`` when
        either part is missing or the cleaned model name is empty.
    """
    username = (username or "").strip()
    model_name = (model_name or "").strip()
    if not username or not model_name:
        return None, "Please provide both username and model name"

    cleaned = "".join(
        ch for ch in model_name.lower().replace(" ", "-")
        if (ch.isascii() and ch.isalnum()) or ch in "-_"
    )
    if not cleaned:
        return None, "Model name must contain at least one letter, digit, '-' or '_'"
    return f"{username}/{cleaned}", None
def _load_classifier(source):
    """Load a tokenizer and a ``len(CATEGORIES)``-label BERT classifier from ``source``."""
    tokenizer = AutoTokenizer.from_pretrained(source)
    model = BertForSequenceClassification.from_pretrained(source, num_labels=len(CATEGORIES))
    return tokenizer, model


def load_model(model_path):
    """Load a model/tokenizer pair into ``CURRENT_MODEL``/``CURRENT_TOKENIZER``.

    If ``model_path`` exists on disk it is loaded as a local directory.
    Otherwise it is treated as a Hugging Face Hub model id. If the Hub load
    fails, the ``bert-base-uncased`` checkpoint is loaded instead. That
    checkpoint has not been fine-tuned for these categories.

    Returns:
        A status string starting with ✅ (loaded as requested), ⚠️ (fell back
        to the base model) or ❌ (nothing could be loaded). Never raises.
    """
    global CURRENT_MODEL, CURRENT_TOKENIZER
    try:
        if os.path.exists(model_path):
            tokenizer, model = _load_classifier(model_path)
            message = f"✅ Model loaded from local path: {model_path}"
        else:
            try:
                tokenizer, model = _load_classifier(model_path)
                message = f"✅ Model loaded from Hugging Face Hub: {model_path}"
            except Exception as hub_error:
                tokenizer, model = _load_classifier("bert-base-uncased")
                message = (
                    "⚠️ Failed to load specified model, using base BERT model instead. "
                    f"Error: {str(hub_error)}"
                )
    except Exception as e:
        return f"❌ Failed to load model: {str(e)}"
    # Publish both only after both loaded, so a failure part-way through never
    # leaves a new tokenizer paired with an old model.
    CURRENT_TOKENIZER, CURRENT_MODEL = tokenizer, model
    return message
def tokenize_function(examples, tokenizer, feature_column, max_length=512):
    """Tokenize the ``feature_column`` texts of a batch, truncating to ``max_length``.

    No padding is applied here. Batches are padded later by the data
    collator used for training.
    """
    texts = examples[feature_column]
    options = dict(truncation=True, padding=False, max_length=max_length)
    return tokenizer(texts, **options)
def compute_metrics(eval_pred):
    """Turn Trainer predictions into accuracy and averaged precision/recall/F1.

    ``eval_pred`` unpacks into ``(logits, label_ids)``. The predicted class is
    the arg-max over axis 1. Scores that would divide by zero are reported
    as 0 (``zero_division=0``).
    """
    logits, label_ids = eval_pred
    predicted = np.argmax(logits, axis=1)
    report = classification_report(label_ids, predicted, output_dict=True, zero_division=0)
    macro_avg = report['macro avg']
    weighted_avg = report['weighted avg']
    return {
        'accuracy': accuracy_score(label_ids, predicted),
        'f1_macro': macro_avg['f1-score'],
        'f1_weighted': weighted_avg['f1-score'],
        'precision_macro': macro_avg['precision'],
        'recall_macro': macro_avg['recall'],
    }
def _training_schedule(num_train_examples, batch_size, num_epochs):
    """Derive step-based Trainer settings from the training-set size.

    Returns ``(total_steps, eval_steps, save_steps, logging_steps, warmup_steps)``.
    ``save_steps`` is always a whole multiple of ``eval_steps``, because
    ``load_best_model_at_end=True`` requires checkpoints to line up with
    evaluations.
    """
    # Round up: the last partial batch still costs a step (assumes
    # dataloader_drop_last stays at its default of False).
    steps_per_epoch = (num_train_examples + batch_size - 1) // batch_size
    total_steps = int(steps_per_epoch * num_epochs)
    eval_steps = max(10, min(100, total_steps // 4))
    target_save_steps = max(20, min(500, total_steps // 2))
    save_steps = eval_steps * max(1, target_save_steps // eval_steps)
    logging_steps = max(5, min(50, total_steps // 10))
    warmup_steps = min(500, total_steps // 10)
    return total_steps, eval_steps, save_steps, logging_steps, warmup_steps


def train_model_inline(uploaded_file, text_column, label_column, num_epochs, batch_size,
                       learning_rate, hf_token, push_to_hub, username, model_name):
    """Fine-tune ``bert-base-uncased`` on an uploaded CSV, streaming a progress log.

    This is a generator. After every stage it yields the full log so far
    (``TRAINING_LOGS`` joined with newlines). Lines added by the progress
    callback during ``trainer.train()`` only appear at the next yield,
    because training blocks the generator.

    The trained model and tokenizer are saved to ``MODEL_PATH`` and installed
    as ``CURRENT_MODEL``/``CURRENT_TOKENIZER``. When ``push_to_hub`` is set,
    they are also pushed to ``username/model_name``. Errors are reported in
    the log rather than raised.
    """
    global TRAINING_LOGS, CURRENT_MODEL, CURRENT_TOKENIZER
    TRAINING_LOGS = []
    logs = TRAINING_LOGS

    def emit(*messages):
        """Append ``messages`` to the log and return the whole log as one string."""
        logs.extend(messages)
        return "\n".join(logs)

    if hf_token:
        yield emit(login_to_hf(hf_token))

    hub_model_id = None
    if push_to_hub:
        hub_model_id, error = validate_hub_model_id(username, model_name)
        if error:
            yield emit(f"❌ {error}")
            return

    if uploaded_file is None:
        yield emit("❌ Please upload a dataset file")
        return
    dataset_file = getattr(uploaded_file, 'name', uploaded_file)

    try:
        # NOTE: slider values may arrive as floats; the Trainer needs an int batch size.
        batch_size = int(batch_size)

        yield emit("📂 Loading dataset from uploaded file...")
        dataset_dict, text_col, label_col = load_and_prepare_local_dataset(
            dataset_file, text_column, label_column
        )
        yield emit(
            "✅ Dataset loaded successfully!",
            f"- Train samples: {len(dataset_dict['train'])}",
            f"- Validation samples: {len(dataset_dict['validation'])}",
        )

        yield emit("🤖 Loading BERT model and tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        model = BertForSequenceClassification.from_pretrained(
            "bert-base-uncased", num_labels=len(CATEGORIES)
        )
        yield emit("✅ Model and tokenizer loaded")

        yield emit("🔤 Tokenizing datasets...")

        def tokenize_batch(examples):
            return tokenize_function(examples, tokenizer, text_col, 512)

        # Keep only the label column alongside the new token features.
        columns_to_remove = [c for c in dataset_dict['train'].column_names if c != label_col]
        tokenized_datasets = dataset_dict.map(
            tokenize_batch, batched=True, remove_columns=columns_to_remove
        )
        # The Trainer expects the targets in a column named 'labels'.
        tokenized_datasets = tokenized_datasets.rename_column(label_col, 'labels')
        yield emit("✅ Tokenization completed")

        output_dir = Path(MODEL_PATH)
        output_dir.mkdir(parents=True, exist_ok=True)

        total_steps, eval_steps, save_steps, logging_steps, warmup_steps = _training_schedule(
            len(tokenized_datasets['train']), batch_size, num_epochs
        )
        yield emit(
            "📋 Training configuration:",
            f"- Total steps: {total_steps}",
            f"- Eval steps: {eval_steps}",
            f"- Save steps: {save_steps}",
            f"- Warmup steps: {warmup_steps}",
        )

        training_args = TrainingArguments(
            output_dir=str(output_dir),
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            warmup_steps=warmup_steps,
            weight_decay=0.01,
            learning_rate=learning_rate,
            logging_dir=str(output_dir / "logs"),
            logging_steps=logging_steps,
            eval_strategy="steps",
            eval_steps=eval_steps,
            save_strategy="steps",
            save_steps=save_steps,
            save_total_limit=2,
            load_best_model_at_end=True,
            metric_for_best_model="eval_accuracy",
            greater_is_better=True,
            push_to_hub=push_to_hub,
            hub_model_id=hub_model_id,
            # "none" disables experiment trackers. In transformers 4.x, None
            # means "all installed integrations" (e.g. wandb).
            report_to="none",
            dataloader_num_workers=0,
            fp16=torch.cuda.is_available(),
            seed=42,
            remove_unused_columns=False,
        )

        class ProgressCallback(TrainerCallback):
            """Append readable progress lines to the shared log list."""

            def __init__(self, logs_list):
                self.logs = logs_list

            def on_train_begin(self, args, state, control, **kwargs):
                self.logs.append("🚀 Starting training...")

            def on_step_end(self, args, state, control, **kwargs):
                if state.global_step % args.logging_steps == 0:
                    self.logs.append(f"Step {state.global_step}/{state.max_steps}")

            def on_epoch_end(self, args, state, control, **kwargs):
                self.logs.append(f"✅ Epoch {int(state.epoch)} completed")

            # The Trainer passes evaluation results as `metrics`, not `logs`.
            def on_evaluate(self, args, state, control, metrics=None, **kwargs):
                if metrics:
                    acc = metrics.get('eval_accuracy', 0)
                    loss = metrics.get('eval_loss', 0)
                    self.logs.append(f"📊 Eval - Accuracy: {acc:.4f}, Loss: {loss:.4f}")

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_datasets['train'],
            eval_dataset=tokenized_datasets['validation'],
            tokenizer=tokenizer,
            data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
            compute_metrics=compute_metrics,
            callbacks=[EarlyStoppingCallback(early_stopping_patience=3), ProgressCallback(logs)],
        )

        try:
            trainer.train()
        except Exception as e:
            yield emit(f"❌ Training failed: {str(e)}")
            return
        yield emit("✅ Training completed successfully!")

        yield emit("💾 Saving model...")
        trainer.save_model()
        tokenizer.save_pretrained(output_dir)
        # Hand the trained model to the predict tabs in inference mode.
        trained_model = trainer.model
        trained_model.eval()
        CURRENT_MODEL = trained_model
        CURRENT_TOKENIZER = tokenizer
        yield emit("✅ Model saved successfully!")

        yield emit("📊 Running final evaluation...")
        try:
            eval_results = trainer.evaluate()
            logs.append("📈 Final Results:")
            for key, value in eval_results.items():
                if isinstance(value, float):
                    logs.append(f"  {key}: {value:.4f}")
                else:
                    logs.append(f"  {key}: {value}")
            with open(output_dir / "eval_results.json", "w", encoding="utf-8") as f:
                json.dump(eval_results, f, indent=2)
        except Exception as e:
            logs.append(f"⚠️ Evaluation error: {str(e)}")
        yield emit()

        if push_to_hub and hub_model_id:
            yield emit(f"🤗 Pushing to Hugging Face Hub: {hub_model_id}")
            try:
                trainer.push_to_hub()
                logs.append(f"✅ Successfully pushed to {hub_model_id}")
            except Exception as e:
                logs.append(f"❌ Push to Hub failed: {str(e)}")
            yield emit()

        yield emit("\n✨ Training completed! Your model is ready to use.")
    except Exception as e:
        yield emit(f"❌ Error during training: {str(e)}")
def predict_text(text, model_path):
    """Classify one complaint and return a Markdown report.

    Empty or ``None`` text is rejected before any model is loaded. The model
    at ``model_path`` is loaded when none is loaded yet or when
    ``model_path`` differs from ``MODEL_PATH``. The report lists the
    predicted category, its confidence and every class probability. It adds a
    warning when the input exceeded 512 tokens and was truncated. Errors are
    returned as messages, never raised.
    """
    if text is None or not text.strip():
        return "Please enter some text to classify."

    if CURRENT_MODEL is None or model_path != MODEL_PATH:
        load_result = load_model(model_path)
        if load_result.startswith("❌"):
            return load_result

    try:
        max_tokens = 512
        was_truncated = len(CURRENT_TOKENIZER(text, truncation=False)['input_ids']) > max_tokens
        inputs = CURRENT_TOKENIZER(text, return_tensors="pt", truncation=True, max_length=max_tokens)
        # The model may be on the GPU after training; keep inputs on the same device.
        inputs = inputs.to(CURRENT_MODEL.device)
        CURRENT_MODEL.eval()
        with torch.no_grad():
            logits = CURRENT_MODEL(**inputs).logits
        probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
        predicted_class_id = int(probabilities.argmax())
        confidence = probabilities[predicted_class_id].item()
        predicted_category = idx_to_category[predicted_class_id]

        lines = [
            f"**Complaint:** {text}",
            f"\n**Predicted Category:** {predicted_category}",
            f"**Confidence:** {confidence:.4f}",
            "\n**All Class Probabilities:**",
        ]
        lines += [
            f"- {category}: {probabilities[i].item():.4f}" for i, category in enumerate(CATEGORIES)
        ]
        if was_truncated:
            lines.append("\n\n⚠️ Note: This complaint was truncated to fit BERT's 512 token limit.")
        return "\n".join(lines)
    except Exception as e:
        return f"❌ Prediction error: {str(e)}"
def predict_csv(csv_file, model_path, preview_limit=20):
    """Classify every row of a CSV file's ``complaint`` column.

    The file is read and validated before any model is loaded. Every row is
    classified and written to ``prediction_results.csv`` with the columns
    complaint, predicted_category, confidence and truncated. The returned
    summary previews only the first ``preview_limit`` rows.

    Returns:
        ``(markdown_summary, results_path)``, or ``(error_message, None)`` on
        failure. Never raises.
    """
    if csv_file is None:
        return "❌ Please upload a CSV file first.", None
    try:
        df = pd.read_csv(getattr(csv_file, 'name', csv_file))
    except Exception as e:
        return f"❌ CSV processing failed: {str(e)}", None
    if 'complaint' not in df.columns:
        return "❌ CSV file must have a 'complaint' column", None

    if CURRENT_MODEL is None or model_path != MODEL_PATH:
        load_result = load_model(model_path)
        if load_result.startswith("❌"):
            return load_result, None

    try:
        max_tokens = 512
        model = CURRENT_MODEL
        tokenizer = CURRENT_TOKENIZER
        model.eval()
        results = []
        predictions_list = []
        truncated_count = 0

        for i, complaint in enumerate(df['complaint'].astype(str)):
            was_truncated = len(tokenizer(complaint, truncation=False)['input_ids']) > max_tokens
            if was_truncated:
                truncated_count += 1
            # The model may be on the GPU after training; keep inputs on the same device.
            inputs = tokenizer(
                complaint, return_tensors="pt", truncation=True, max_length=max_tokens
            ).to(model.device)
            with torch.no_grad():
                probabilities = torch.nn.functional.softmax(model(**inputs).logits, dim=-1)[0]
            predicted_idx = int(probabilities.argmax())
            confidence = probabilities[predicted_idx].item()
            predicted_category = idx_to_category[predicted_idx]

            predictions_list.append({
                'complaint': complaint,
                'predicted_category': predicted_category,
                'confidence': confidence,
                'truncated': was_truncated,
            })

            if i < preview_limit:
                truncation_mark = " ⚠️" if was_truncated else ""
                preview = complaint if len(complaint) <= 50 else complaint[:47] + "..."
                results.append(f"Complaint {i+1}{truncation_mark}: {preview}")
                results.append(f"Predicted: {predicted_category} (confidence: {confidence:.3f})\n")

        if len(df) > preview_limit:
            results.append(
                f"... and {len(df) - preview_limit} more "
                f"(showing first {preview_limit} out of {len(df)} complaints)"
            )
        if truncated_count > 0:
            results.append(f"\n⚠️ {truncated_count} complaints were truncated to fit BERT's 512 token limit.")

        results_file = "prediction_results.csv"
        pd.DataFrame(
            predictions_list,
            columns=['complaint', 'predicted_category', 'confidence', 'truncated'],
        ).to_csv(results_file, index=False)
        results.append(f"\n💾 Full results saved to {results_file}")
        return "\n".join(results), results_file
    except Exception as e:
        return f"❌ CSV processing failed: {str(e)}", None
def push_to_hub_after_training(model_path, username, model_name, token):
    """Upload a locally saved model and tokenizer to the Hugging Face Hub.

    The cheap local checks run first: token present, valid repo id, and model
    directory exists. Only then does it log in, load the model and push.
    Returns a status message; never raises.
    """
    if not token or not token.strip():
        return "❌ Please provide a Hugging Face token"
    try:
        hub_model_id, error = validate_hub_model_id(username, model_name)
        if error:
            return f"❌ {error}"
        if not os.path.exists(model_path):
            return "❌ No trained model found. Please train a model first."

        login(token=token.strip())

        try:
            model = BertForSequenceClassification.from_pretrained(model_path)
            tokenizer = AutoTokenizer.from_pretrained(model_path)
        except Exception as e:
            return f"❌ Failed to load model: {str(e)}"

        try:
            model.push_to_hub(hub_model_id)
            tokenizer.push_to_hub(hub_model_id)
        except Exception as e:
            return f"❌ Failed to push to Hub: {str(e)}"
        return f"✅ Successfully pushed model to {hub_model_id}"
    except Exception as e:
        return f"❌ Error: {str(e)}"
def count_tokens(text):
    """Return a status string with the number of tokenizer ``input_ids`` for ``text``.

    Empty or ``None`` text returns a placeholder. If no tokenizer is loaded
    yet, ``bert-base-uncased`` is loaded into ``CURRENT_TOKENIZER``. Counts
    above 512 get a truncation warning.
    """
    global CURRENT_TOKENIZER
    if not text:
        return "Enter text to see token count"

    if CURRENT_TOKENIZER is None:
        try:
            CURRENT_TOKENIZER = AutoTokenizer.from_pretrained("bert-base-uncased")
            logger.info("Fallback: Tokenizer loaded in count_tokens function.")
        except Exception as e:
            logger.error("Failed to load tokenizer in count_tokens fallback: %s", e)
            return "❌ Error: Tokenizer not loaded. Please load a model or check logs."

    count = len(CURRENT_TOKENIZER(text, truncation=False)['input_ids'])
    if count > 512:
        return f"⚠️ **Token count: {count}/512** - Text will be truncated for BERT"
    return f"Token count: {count}/512"
def get_available_datasets(directory="."):
    """List CSV files in ``directory`` as ``"<name> (<rows> rows)"`` strings.

    Files are listed in sorted order. Files that cannot be parsed appear as
    ``"<name> (Error reading)"``. When there are no CSV files, a one-element
    placeholder list is returned instead of an empty list.
    """
    available_files = []
    for name in sorted(os.listdir(directory)):
        if not name.endswith(".csv"):
            continue
        try:
            row_count = len(pd.read_csv(os.path.join(directory, name)))
        except Exception:
            available_files.append(f"{name} (Error reading)")
        else:
            available_files.append(f"{name} ({row_count} rows)")
    return available_files or ["No CSV files found in current directory"]
def display_available_datasets():
    """Render the entries from ``get_available_datasets()`` as a Markdown bullet list.

    ``get_available_datasets`` substitutes a placeholder entry when no CSV
    files exist, so the empty-list branch is only a safeguard.
    """
    entries = get_available_datasets()
    if not entries:
        return "No CSV files found in the current directory."
    bullet_list = "\n".join(f"- {entry}" for entry in entries)
    return "**Available CSV files:**\n\n" + bullet_list
# Initialize tokenizer on startup so the token counter works before any model is loaded.
if CURRENT_TOKENIZER is None:
    try:
        CURRENT_TOKENIZER = AutoTokenizer.from_pretrained("bert-base-uncased")
        print("✅ Tokenizer initialized successfully")
    except Exception as e:
        print(f"⚠️ Warning: Could not initialize tokenizer globally: {e}")
print("🚀 Launching BERT Complaint Classifier...")
print("🌐 Available at: http://localhost:7860")
# The entire Gradio UI definition (components and event wiring) must be within
# this single Blocks context.
with gr.Blocks(title="BERT Complaint Classifier", theme=gr.themes.Soft()) as app:
    gr.Markdown("# BERT Complaint Classifier 🗣️🤖")
    gr.Markdown("Fine-tune a BERT model or use an existing one to classify customer complaints.")

    # --- Training -----------------------------------------------------------
    with gr.Tab("Fine-tune Model"):
        gr.Markdown("## 🏋️ Fine-tune a New Model")
        with gr.Column(variant="panel"):
            gr.Markdown("### 🛠️ Training Configuration")
            with gr.Row():
                uploaded_file = gr.File(label="Upload Training CSV File", type="filepath", file_types=[".csv"])
                preview_btn = gr.Button("Preview Dataset")
            preview_output = gr.Markdown("Dataset info will appear here")
            with gr.Row():
                text_column_input = gr.Textbox(label="Text Column Name", value="complaint")
                label_column_input = gr.Textbox(label="Label Column Name", value="category")
            gr.Markdown("---")
            with gr.Row():
                num_epochs_slider = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of Epochs")
                batch_size_slider = gr.Slider(minimum=4, maximum=32, step=4, value=8, label="Batch Size")
                learning_rate_slider = gr.Slider(minimum=1e-6, maximum=1e-4, step=1e-6, value=2e-5, label="Learning Rate")
            gr.Markdown("---")
            gr.Markdown("### ☁️ Hugging Face Hub (Optional)")
            with gr.Row():
                push_to_hub_checkbox = gr.Checkbox(label="Push to Hugging Face Hub")
                hf_token_input = gr.Textbox(label="Hugging Face Token", type="password")
            with gr.Row():
                hf_username_input = gr.Textbox(label="Hugging Face Username")
                hf_model_name_input = gr.Textbox(label="Model Name (for Hub)", value="bert-complaint-classifier")
            train_btn = gr.Button("🚀 Start Training", variant="primary")
        gr.Markdown("---")
        training_log_output = gr.Textbox(label="Training Logs", lines=20, max_lines=20, interactive=False)

    # --- Prediction ---------------------------------------------------------
    with gr.Tab("Predict"):
        gr.Markdown("## 🔮 Make Predictions")
        gr.Markdown("Choose a method to classify complaints.")
        with gr.Tab("Predict Single Text"):
            with gr.Column(variant="panel"):
                gr.Markdown("### Classify a Single Complaint")
                model_path_input = gr.Textbox(label="Model Path", value=MODEL_PATH, placeholder="Enter model path or HuggingFace model ID")
                load_model_btn = gr.Button("Load Model")
                model_status = gr.Textbox(label="Model Status", interactive=False)
                gr.Markdown("---")
                text_input = gr.Textbox(
                    label="Enter complaint text",
                    lines=3,
                    placeholder="Type your complaint here...",
                )
                token_counter = gr.Textbox(label="Token Count", interactive=False, value="Enter text to see token count")
                predict_btn = gr.Button("🔮 Predict Category", variant="primary")
                prediction_output = gr.Markdown("Prediction results will appear here")
        with gr.Tab("Predict CSV File"):
            with gr.Column(variant="panel"):
                gr.Markdown("### Classify Multiple Complaints from CSV")
                gr.Markdown("Upload a CSV file with a 'complaint' column to classify multiple complaints at once.")
                csv_model_path = gr.Textbox(label="Model Path", value=MODEL_PATH, placeholder="Enter model path or HuggingFace model ID")
                csv_load_btn = gr.Button("Load Model")
                csv_model_status = gr.Textbox(label="Model Status", interactive=False)
                gr.Markdown("---")
                csv_file_input = gr.File(label="Upload CSV File", type="filepath", file_types=[".csv"])
                csv_predict_btn = gr.Button("🔮 Predict All", variant="primary")
                csv_prediction_output = gr.Markdown("CSV prediction results will appear here")
                csv_download = gr.File(label="Download Results", interactive=False)

    # --- Hub upload ---------------------------------------------------------
    with gr.Tab("Push to Hub"):
        gr.Markdown("## 🤗 Push Trained Model to Hugging Face Hub")
        gr.Markdown("Upload your locally trained model to the Hugging Face Hub for sharing.")
        with gr.Column(variant="panel"):
            hub_model_path = gr.Textbox(label="Local Model Path", value=MODEL_PATH)
            hub_username = gr.Textbox(label="Hugging Face Username")
            hub_model_name = gr.Textbox(label="Model Name", value="bert-complaint-classifier")
            hub_token = gr.Textbox(label="Hugging Face Token", type="password")
            push_hub_btn = gr.Button("🚀 Push to Hub", variant="primary")
            push_hub_output = gr.Markdown("Push results will appear here")

    # --- Dataset info -------------------------------------------------------
    with gr.Tab("Dataset Info"):
        gr.Markdown("## 📊 Dataset Information")
        gr.Markdown("View information about available datasets and model categories.")
        with gr.Column(variant="panel"):
            gr.Markdown("### 🎯 Model Categories")
            categories_info = gr.Markdown(
                "**Available Categories:**\n\n"
                + "\n".join(f"- **{cat}** (index: {idx})" for idx, cat in idx_to_category.items())
            )
            gr.Markdown("---")
            gr.Markdown("### 📁 Available Datasets")
            datasets_btn = gr.Button("🔍 Scan for CSV Files")
            datasets_info = gr.Markdown("Click 'Scan for CSV Files' to see available datasets")
            gr.Markdown("---")
            gr.Markdown("### 💡 Tips")
            gr.Markdown("""
            **Dataset Format:**
            - CSV file with at least two columns
            - One column for text (complaints)
            - One column for labels/categories
            - Labels can be text (will be auto-mapped) or numeric indices (0, 1, 2)

            **Training Tips:**
            - Start with 3 epochs and adjust based on results
            - Use batch size 8-16 for most datasets
            - Learning rate 2e-5 works well for BERT fine-tuning
            - Early stopping (patience of 3 evaluations) is applied automatically to limit overfitting

            **Token Limits:**
            - BERT has a 512 token limit
            - Long texts will be automatically truncated
            - Monitor the token counter when entering text
            """)

    # Connect functions to UI components
    preview_btn.click(
        preview_dataset,
        inputs=[uploaded_file, text_column_input, label_column_input],
        outputs=preview_output,
    )
    train_btn.click(
        train_model_inline,
        inputs=[
            uploaded_file,
            text_column_input,
            label_column_input,
            num_epochs_slider,
            batch_size_slider,
            learning_rate_slider,
            hf_token_input,
            push_to_hub_checkbox,
            hf_username_input,
            hf_model_name_input,
        ],
        outputs=training_log_output,
    )
    load_model_btn.click(load_model, inputs=model_path_input, outputs=model_status)
    predict_btn.click(
        predict_text,
        inputs=[text_input, model_path_input],
        outputs=prediction_output,
    )
    text_input.change(count_tokens, inputs=text_input, outputs=token_counter)
    csv_load_btn.click(load_model, inputs=csv_model_path, outputs=csv_model_status)
    csv_predict_btn.click(
        predict_csv,
        inputs=[csv_file_input, csv_model_path],
        outputs=[csv_prediction_output, csv_download],
    )
    push_hub_btn.click(
        push_to_hub_after_training,
        inputs=[hub_model_path, hub_username, hub_model_name, hub_token],
        outputs=push_hub_output,
    )
    datasets_btn.click(display_available_datasets, outputs=datasets_info)
    # Run a check for available datasets on app load
    app.load(display_available_datasets, outputs=datasets_info)
# Launch the Gradio app when run as a script.
if __name__ == "__main__":
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
    )