Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -35,16 +35,16 @@ TRAINING_LOGS = []
|
|
| 35 |
CURRENT_MODEL = None
|
| 36 |
CURRENT_TOKENIZER = None
|
| 37 |
|
|
|
|
|
|
|
| 38 |
def load_and_prepare_local_dataset(file_path, text_column, label_column, test_size=0.2):
|
| 39 |
"""Load and prepare local CSV dataset for training"""
|
| 40 |
try:
|
| 41 |
if not os.path.exists(file_path):
|
| 42 |
raise FileNotFoundError(f"Dataset file not found: {file_path}")
|
| 43 |
|
| 44 |
-
# Load the CSV file
|
| 45 |
df = pd.read_csv(file_path)
|
| 46 |
|
| 47 |
-
# Verify required columns exist
|
| 48 |
if text_column not in df.columns:
|
| 49 |
available_cols = list(df.columns)
|
| 50 |
raise ValueError(f"Text column '{text_column}' not found. Available columns: {available_cols}")
|
|
@@ -53,24 +53,19 @@ def load_and_prepare_local_dataset(file_path, text_column, label_column, test_si
|
|
| 53 |
available_cols = list(df.columns)
|
| 54 |
raise ValueError(f"Label column '{label_column}' not found. Available columns: {available_cols}")
|
| 55 |
|
| 56 |
-
# Clean the data
|
| 57 |
df = df.dropna(subset=[text_column, label_column])
|
| 58 |
df[text_column] = df[text_column].astype(str)
|
| 59 |
|
| 60 |
-
# Handle different label formats
|
| 61 |
if df[label_column].dtype == 'object':
|
| 62 |
-
# If labels are text, convert to indices
|
| 63 |
unique_labels = df[label_column].unique()
|
| 64 |
if len(unique_labels) > len(CATEGORIES):
|
| 65 |
raise ValueError(f"Too many unique labels ({len(unique_labels)}). Expected max {len(CATEGORIES)}")
|
| 66 |
|
| 67 |
-
# Try to map text labels to our categories
|
| 68 |
label_mapping = {}
|
| 69 |
for label in unique_labels:
|
| 70 |
if label in category_to_idx:
|
| 71 |
label_mapping[label] = category_to_idx[label]
|
| 72 |
else:
|
| 73 |
-
# Auto-assign if not found
|
| 74 |
available_indices = set(range(len(CATEGORIES))) - set(label_mapping.values())
|
| 75 |
if available_indices:
|
| 76 |
label_mapping[label] = min(available_indices)
|
|
@@ -79,14 +74,11 @@ def load_and_prepare_local_dataset(file_path, text_column, label_column, test_si
|
|
| 79 |
|
| 80 |
df['label_idx'] = df[label_column].map(label_mapping)
|
| 81 |
else:
|
| 82 |
-
# If labels are already numeric
|
| 83 |
df['label_idx'] = df[label_column].astype(int)
|
| 84 |
|
| 85 |
-
# Verify label indices are valid
|
| 86 |
if df['label_idx'].min() < 0 or df['label_idx'].max() >= len(CATEGORIES):
|
| 87 |
raise ValueError(f"Label indices must be between 0 and {len(CATEGORIES)-1}")
|
| 88 |
|
| 89 |
-
# Create train/validation split
|
| 90 |
train_df, val_df = train_test_split(
|
| 91 |
df,
|
| 92 |
test_size=test_size,
|
|
@@ -94,7 +86,6 @@ def load_and_prepare_local_dataset(file_path, text_column, label_column, test_si
|
|
| 94 |
stratify=df['label_idx']
|
| 95 |
)
|
| 96 |
|
| 97 |
-
# Convert to Hugging Face datasets
|
| 98 |
train_dataset = Dataset.from_pandas(train_df[[text_column, 'label_idx']])
|
| 99 |
val_dataset = Dataset.from_pandas(val_df[[text_column, 'label_idx']])
|
| 100 |
|
|
@@ -114,7 +105,6 @@ def preview_dataset(uploaded_file, text_column, label_column):
|
|
| 114 |
if uploaded_file is None:
|
| 115 |
return "Please upload a dataset file first."
|
| 116 |
|
| 117 |
-
# Get the file path from the uploaded file
|
| 118 |
file_path = uploaded_file.name if hasattr(uploaded_file, 'name') else uploaded_file
|
| 119 |
|
| 120 |
df = pd.read_csv(file_path)
|
|
@@ -162,11 +152,9 @@ def validate_hub_model_id(username, model_name):
|
|
| 162 |
if not username or not model_name:
|
| 163 |
return None, "Please provide both username and model name"
|
| 164 |
|
| 165 |
-
# Clean up the model name
|
| 166 |
model_name = model_name.strip().lower().replace(" ", "-")
|
| 167 |
model_name = ''.join(c for c in model_name if c.isalnum() or c in ['-', '_'])
|
| 168 |
|
| 169 |
-
# Construct the full model ID
|
| 170 |
hub_model_id = f"{username}/{model_name}"
|
| 171 |
|
| 172 |
return hub_model_id, None
|
|
@@ -176,7 +164,6 @@ def load_model(model_path):
|
|
| 176 |
global CURRENT_MODEL, CURRENT_TOKENIZER
|
| 177 |
|
| 178 |
try:
|
| 179 |
-
# Try loading from local path first
|
| 180 |
if os.path.exists(model_path):
|
| 181 |
CURRENT_TOKENIZER = AutoTokenizer.from_pretrained(model_path)
|
| 182 |
CURRENT_MODEL = BertForSequenceClassification.from_pretrained(
|
|
@@ -185,7 +172,6 @@ def load_model(model_path):
|
|
| 185 |
)
|
| 186 |
return f"✅ Model loaded from local path: {model_path}"
|
| 187 |
|
| 188 |
-
# If local path doesn't exist, try loading from Hub
|
| 189 |
try:
|
| 190 |
CURRENT_TOKENIZER = AutoTokenizer.from_pretrained(model_path)
|
| 191 |
CURRENT_MODEL = BertForSequenceClassification.from_pretrained(
|
|
@@ -194,7 +180,6 @@ def load_model(model_path):
|
|
| 194 |
)
|
| 195 |
return f"✅ Model loaded from Hugging Face Hub: {model_path}"
|
| 196 |
except Exception as hub_error:
|
| 197 |
-
# If both local and hub loading fail, fall back to base model
|
| 198 |
CURRENT_TOKENIZER = AutoTokenizer.from_pretrained("bert-base-uncased")
|
| 199 |
CURRENT_MODEL = BertForSequenceClassification.from_pretrained(
|
| 200 |
"bert-base-uncased",
|
|
@@ -242,7 +227,6 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
|
|
| 242 |
TRAINING_LOGS.append(login_result)
|
| 243 |
yield "\n".join(TRAINING_LOGS)
|
| 244 |
|
| 245 |
-
# Validate hub model ID if pushing to hub
|
| 246 |
if push_to_hub:
|
| 247 |
hub_model_id, error = validate_hub_model_id(username, model_name)
|
| 248 |
if error:
|
|
@@ -252,7 +236,6 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
|
|
| 252 |
else:
|
| 253 |
hub_model_id = None
|
| 254 |
|
| 255 |
-
# Validate uploaded file
|
| 256 |
if uploaded_file is None:
|
| 257 |
TRAINING_LOGS.append("❌ Please upload a dataset file")
|
| 258 |
yield "\n".join(TRAINING_LOGS)
|
|
@@ -261,7 +244,6 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
|
|
| 261 |
dataset_file = uploaded_file.name if hasattr(uploaded_file, 'name') else uploaded_file
|
| 262 |
|
| 263 |
try:
|
| 264 |
-
# Load and prepare dataset
|
| 265 |
TRAINING_LOGS.append(f"π Loading dataset from uploaded file...")
|
| 266 |
yield "\n".join(TRAINING_LOGS)
|
| 267 |
|
|
@@ -274,7 +256,6 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
|
|
| 274 |
TRAINING_LOGS.append(f"- Validation samples: {len(dataset_dict['validation'])}")
|
| 275 |
yield "\n".join(TRAINING_LOGS)
|
| 276 |
|
| 277 |
-
# Load model and tokenizer
|
| 278 |
TRAINING_LOGS.append("π€ Loading BERT model and tokenizer...")
|
| 279 |
yield "\n".join(TRAINING_LOGS)
|
| 280 |
|
|
@@ -287,14 +268,12 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
|
|
| 287 |
TRAINING_LOGS.append("✅ Model and tokenizer loaded")
|
| 288 |
yield "\n".join(TRAINING_LOGS)
|
| 289 |
|
| 290 |
-
# Tokenize datasets
|
| 291 |
TRAINING_LOGS.append("π€ Tokenizing datasets...")
|
| 292 |
yield "\n".join(TRAINING_LOGS)
|
| 293 |
|
| 294 |
def tokenize_batch(examples):
|
| 295 |
return tokenize_function(examples, tokenizer, final_text_col, 512)
|
| 296 |
|
| 297 |
-
# Get columns to remove (keep only label column and tokenized features)
|
| 298 |
columns_to_remove = [col for col in dataset_dict['train'].column_names if col != final_label_col]
|
| 299 |
|
| 300 |
tokenized_datasets = dataset_dict.map(
|
|
@@ -303,17 +282,14 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
|
|
| 303 |
remove_columns=columns_to_remove
|
| 304 |
)
|
| 305 |
|
| 306 |
-
# Rename label column to 'labels' (required by Trainer)
|
| 307 |
tokenized_datasets = tokenized_datasets.rename_column(final_label_col, 'labels')
|
| 308 |
|
| 309 |
TRAINING_LOGS.append("✅ Tokenization completed")
|
| 310 |
yield "\n".join(TRAINING_LOGS)
|
| 311 |
|
| 312 |
-
# Set up training
|
| 313 |
output_dir = Path(MODEL_PATH)
|
| 314 |
output_dir.mkdir(parents=True, exist_ok=True)
|
| 315 |
|
| 316 |
-
# Calculate steps
|
| 317 |
total_steps = len(tokenized_datasets['train']) // batch_size * num_epochs
|
| 318 |
eval_steps = max(10, min(100, total_steps // 4))
|
| 319 |
save_steps = max(20, min(500, total_steps // 2))
|
|
@@ -326,7 +302,6 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
|
|
| 326 |
TRAINING_LOGS.append(f"- Warmup steps: {warmup_steps}")
|
| 327 |
yield "\n".join(TRAINING_LOGS)
|
| 328 |
|
| 329 |
-
# Training arguments
|
| 330 |
training_args = TrainingArguments(
|
| 331 |
output_dir=str(output_dir),
|
| 332 |
num_train_epochs=num_epochs,
|
|
@@ -353,10 +328,8 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
|
|
| 353 |
remove_unused_columns=False,
|
| 354 |
)
|
| 355 |
|
| 356 |
-
# Data collator
|
| 357 |
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
|
| 358 |
|
| 359 |
-
# Create trainer
|
| 360 |
trainer = Trainer(
|
| 361 |
model=model,
|
| 362 |
args=training_args,
|
|
@@ -371,7 +344,6 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
|
|
| 371 |
TRAINING_LOGS.append("π Starting training...")
|
| 372 |
yield "\n".join(TRAINING_LOGS)
|
| 373 |
|
| 374 |
-
# Custom training loop with progress updates
|
| 375 |
class ProgressCallback:
|
| 376 |
def __init__(self, logs_list):
|
| 377 |
self.logs = logs_list
|
|
@@ -395,7 +367,6 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
|
|
| 395 |
progress_callback = ProgressCallback(TRAINING_LOGS)
|
| 396 |
trainer.add_callback(progress_callback)
|
| 397 |
|
| 398 |
-
# Train the model
|
| 399 |
try:
|
| 400 |
trainer.train()
|
| 401 |
TRAINING_LOGS.append("✅ Training completed successfully!")
|
|
@@ -405,21 +376,18 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
|
|
| 405 |
yield "\n".join(TRAINING_LOGS)
|
| 406 |
return
|
| 407 |
|
| 408 |
-
# Save the model
|
| 409 |
TRAINING_LOGS.append("💾 Saving model...")
|
| 410 |
yield "\n".join(TRAINING_LOGS)
|
| 411 |
|
| 412 |
trainer.save_model()
|
| 413 |
tokenizer.save_pretrained(output_dir)
|
| 414 |
|
| 415 |
-
# Update global model and tokenizer
|
| 416 |
CURRENT_MODEL = model
|
| 417 |
CURRENT_TOKENIZER = tokenizer
|
| 418 |
|
| 419 |
TRAINING_LOGS.append("✅ Model saved successfully!")
|
| 420 |
yield "\n".join(TRAINING_LOGS)
|
| 421 |
|
| 422 |
-
# Final evaluation
|
| 423 |
TRAINING_LOGS.append("π Running final evaluation...")
|
| 424 |
yield "\n".join(TRAINING_LOGS)
|
| 425 |
|
|
@@ -432,7 +400,6 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
|
|
| 432 |
else:
|
| 433 |
TRAINING_LOGS.append(f" {key}: {value}")
|
| 434 |
|
| 435 |
-
# Save results
|
| 436 |
with open(output_dir / "eval_results.json", "w") as f:
|
| 437 |
json.dump(eval_results, f, indent=2)
|
| 438 |
|
|
@@ -441,7 +408,6 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
|
|
| 441 |
|
| 442 |
yield "\n".join(TRAINING_LOGS)
|
| 443 |
|
| 444 |
-
# Push to hub if requested
|
| 445 |
if push_to_hub and hub_model_id:
|
| 446 |
TRAINING_LOGS.append(f"π€ Pushing to Hugging Face Hub: {hub_model_id}")
|
| 447 |
yield "\n".join(TRAINING_LOGS)
|
|
@@ -465,7 +431,6 @@ def predict_text(text, model_path):
|
|
| 465 |
"""Make a prediction on a single text input"""
|
| 466 |
global CURRENT_MODEL, CURRENT_TOKENIZER
|
| 467 |
|
| 468 |
-
# Load the model if it's not loaded or a different one is requested
|
| 469 |
if CURRENT_MODEL is None or model_path != MODEL_PATH:
|
| 470 |
load_result = load_model(model_path)
|
| 471 |
if load_result.startswith("❌"):
|
|
@@ -475,24 +440,19 @@ def predict_text(text, model_path):
|
|
| 475 |
if not text.strip():
|
| 476 |
return "Please enter some text to classify."
|
| 477 |
|
| 478 |
-
# Check if text was truncated
|
| 479 |
original_tokens = CURRENT_TOKENIZER(text, truncation=False)
|
| 480 |
was_truncated = len(original_tokens['input_ids']) > 512
|
| 481 |
|
| 482 |
-
# Tokenize input
|
| 483 |
inputs = CURRENT_TOKENIZER(text, return_tensors="pt", truncation=True, max_length=512)
|
| 484 |
|
| 485 |
-
# Make prediction
|
| 486 |
with torch.no_grad():
|
| 487 |
outputs = CURRENT_MODEL(**inputs)
|
| 488 |
predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
|
| 489 |
predicted_class_id = predictions.argmax().item()
|
| 490 |
confidence = predictions.max().item()
|
| 491 |
|
| 492 |
-
# Get predicted category
|
| 493 |
predicted_category = idx_to_category[predicted_class_id]
|
| 494 |
|
| 495 |
-
# Format result
|
| 496 |
truncation_warning = "\n\n⚠️ Note: This complaint was truncated to fit BERT's 512 token limit." if was_truncated else ""
|
| 497 |
|
| 498 |
result = []
|
|
@@ -516,21 +476,19 @@ def predict_csv(csv_file, model_path):
|
|
| 516 |
"""Make predictions on a CSV file with complaints"""
|
| 517 |
global CURRENT_MODEL, CURRENT_TOKENIZER
|
| 518 |
|
| 519 |
-
# Load the model if needed
|
| 520 |
if CURRENT_MODEL is None or model_path != MODEL_PATH:
|
| 521 |
load_result = load_model(model_path)
|
| 522 |
if load_result.startswith("❌"):
|
| 523 |
-
return load_result
|
| 524 |
|
| 525 |
try:
|
| 526 |
-
# Read the CSV file
|
| 527 |
if hasattr(csv_file, 'name'):
|
| 528 |
df = pd.read_csv(csv_file.name)
|
| 529 |
else:
|
| 530 |
df = pd.read_csv(csv_file)
|
| 531 |
|
| 532 |
if 'complaint' not in df.columns:
|
| 533 |
-
return "❌ CSV file must have a 'complaint' column"
|
| 534 |
|
| 535 |
results = []
|
| 536 |
predictions_list = []
|
|
@@ -539,13 +497,11 @@ def predict_csv(csv_file, model_path):
|
|
| 539 |
for i, row in enumerate(df.iterrows()):
|
| 540 |
complaint = str(row[1]['complaint'])
|
| 541 |
|
| 542 |
-
# Check for truncation
|
| 543 |
original_tokens = CURRENT_TOKENIZER(complaint, truncation=False)
|
| 544 |
was_truncated = len(original_tokens['input_ids']) > 512
|
| 545 |
if was_truncated:
|
| 546 |
truncated_count += 1
|
| 547 |
|
| 548 |
-
# Predict
|
| 549 |
inputs = CURRENT_TOKENIZER(complaint, return_tensors="pt", truncation=True, max_length=512)
|
| 550 |
with torch.no_grad():
|
| 551 |
outputs = CURRENT_MODEL(**inputs)
|
|
@@ -573,16 +529,15 @@ def predict_csv(csv_file, model_path):
|
|
| 573 |
if truncated_count > 0:
|
| 574 |
results.append(f"\n⚠️ {truncated_count} complaints were truncated to fit BERT's 512 token limit.")
|
| 575 |
|
| 576 |
-
# Save full results to a CSV file
|
| 577 |
results_df = pd.DataFrame(predictions_list)
|
| 578 |
results_file = "prediction_results.csv"
|
| 579 |
results_df.to_csv(results_file, index=False)
|
| 580 |
results.append(f"\n💾 Full results saved to {results_file}")
|
| 581 |
|
| 582 |
-
return "\n".join(results)
|
| 583 |
|
| 584 |
except Exception as e:
|
| 585 |
-
return f"❌ CSV processing failed: {str(e)}"
|
| 586 |
|
| 587 |
def push_to_hub_after_training(model_path, username, model_name, token):
|
| 588 |
"""Push a trained model to Hugging Face Hub"""
|
|
@@ -594,7 +549,6 @@ def push_to_hub_after_training(model_path, username, model_name, token):
|
|
| 594 |
if error:
|
| 595 |
return f"❌ {error}"
|
| 596 |
|
| 597 |
-
# Login and load model
|
| 598 |
login(token)
|
| 599 |
if not os.path.exists(model_path):
|
| 600 |
return "❌ No trained model found. Please train a model first."
|
|
@@ -605,7 +559,6 @@ def push_to_hub_after_training(model_path, username, model_name, token):
|
|
| 605 |
except Exception as e:
|
| 606 |
return f"❌ Failed to load model: {str(e)}"
|
| 607 |
|
| 608 |
-
# Push to Hub
|
| 609 |
try:
|
| 610 |
model.push_to_hub(hub_model_id)
|
| 611 |
tokenizer.push_to_hub(hub_model_id)
|
|
@@ -650,6 +603,8 @@ def display_available_datasets():
|
|
| 650 |
else:
|
| 651 |
return "No CSV files found in the current directory."
|
| 652 |
|
|
|
|
|
|
|
| 653 |
# Initialize tokenizer on startup
|
| 654 |
if CURRENT_TOKENIZER is None:
|
| 655 |
try:
|
|
@@ -661,7 +616,7 @@ if CURRENT_TOKENIZER is None:
|
|
| 661 |
print("π Launching BERT Complaint Classifier...")
|
| 662 |
print("π Available at: http://localhost:7860")
|
| 663 |
|
| 664 |
-
# The entire Gradio UI definition must be within
|
| 665 |
with gr.Blocks(title="BERT Complaint Classifier", theme=gr.themes.Soft()) as app:
|
| 666 |
gr.Markdown("# BERT Complaint Classifier π£οΈπ€")
|
| 667 |
gr.Markdown("Fine-tune a BERT model or use an existing one to classify customer complaints.")
|
|
@@ -724,7 +679,6 @@ with gr.Blocks(title="BERT Complaint Classifier", theme=gr.themes.Soft()) as app
|
|
| 724 |
predict_btn = gr.Button("Classify Complaint", variant="primary")
|
| 725 |
single_prediction_output = gr.Markdown("Prediction will appear here...")
|
| 726 |
|
| 727 |
-
# Link token count to text input
|
| 728 |
text_input.change(count_tokens, inputs=text_input, outputs=token_count_output)
|
| 729 |
|
| 730 |
with gr.Tab("Predict from CSV"):
|
|
@@ -740,7 +694,6 @@ with gr.Blocks(title="BERT Complaint Classifier", theme=gr.themes.Soft()) as app
|
|
| 740 |
csv_prediction_output = gr.Markdown("Predictions will appear here...")
|
| 741 |
download_link = gr.File(label="Download Full Predictions", interactive=False)
|
| 742 |
|
| 743 |
-
# Link prediction buttons to functions
|
| 744 |
predict_btn.click(
|
| 745 |
predict_text,
|
| 746 |
inputs=[text_input, model_path_input],
|
|
@@ -797,16 +750,14 @@ with gr.Blocks(title="BERT Complaint Classifier", theme=gr.themes.Soft()) as app
|
|
| 797 |
hub_token_input_push = gr.Textbox(label="Hugging Face Token", type="password")
|
| 798 |
|
| 799 |
push_btn = gr.Button("π Push Model to Hub", variant="primary")
|
| 800 |
-
push_output = gr.
|
| 801 |
|
| 802 |
-
# Link the push button
|
| 803 |
push_btn.click(
|
| 804 |
push_to_hub_after_training,
|
| 805 |
inputs=[gr.Textbox(value=MODEL_PATH, visible=False), hub_username_input_push, hub_model_name_input_push, hub_token_input_push],
|
| 806 |
outputs=push_output
|
| 807 |
)
|
| 808 |
|
| 809 |
-
# All button clicks and UI logic now correctly indented within the app block
|
| 810 |
preview_btn.click(
|
| 811 |
preview_dataset,
|
| 812 |
inputs=[uploaded_file, text_column_input, label_column_input],
|
|
@@ -835,10 +786,8 @@ with gr.Blocks(title="BERT Complaint Classifier", theme=gr.themes.Soft()) as app
|
|
| 835 |
outputs=available_datasets
|
| 836 |
)
|
| 837 |
|
| 838 |
-
# Show datasets on load
|
| 839 |
app.load(display_available_datasets, outputs=available_datasets)
|
| 840 |
|
| 841 |
-
# Launch the app
|
| 842 |
if __name__ == "__main__":
|
| 843 |
app.launch(
|
| 844 |
server_name="0.0.0.0",
|
|
|
|
| 35 |
CURRENT_MODEL = None
|
| 36 |
CURRENT_TOKENIZER = None
|
| 37 |
|
| 38 |
+
# --- Application Logic Functions (No change needed here, they are correctly indented) ---
|
| 39 |
+
|
| 40 |
def load_and_prepare_local_dataset(file_path, text_column, label_column, test_size=0.2):
|
| 41 |
"""Load and prepare local CSV dataset for training"""
|
| 42 |
try:
|
| 43 |
if not os.path.exists(file_path):
|
| 44 |
raise FileNotFoundError(f"Dataset file not found: {file_path}")
|
| 45 |
|
|
|
|
| 46 |
df = pd.read_csv(file_path)
|
| 47 |
|
|
|
|
| 48 |
if text_column not in df.columns:
|
| 49 |
available_cols = list(df.columns)
|
| 50 |
raise ValueError(f"Text column '{text_column}' not found. Available columns: {available_cols}")
|
|
|
|
| 53 |
available_cols = list(df.columns)
|
| 54 |
raise ValueError(f"Label column '{label_column}' not found. Available columns: {available_cols}")
|
| 55 |
|
|
|
|
| 56 |
df = df.dropna(subset=[text_column, label_column])
|
| 57 |
df[text_column] = df[text_column].astype(str)
|
| 58 |
|
|
|
|
| 59 |
if df[label_column].dtype == 'object':
|
|
|
|
| 60 |
unique_labels = df[label_column].unique()
|
| 61 |
if len(unique_labels) > len(CATEGORIES):
|
| 62 |
raise ValueError(f"Too many unique labels ({len(unique_labels)}). Expected max {len(CATEGORIES)}")
|
| 63 |
|
|
|
|
| 64 |
label_mapping = {}
|
| 65 |
for label in unique_labels:
|
| 66 |
if label in category_to_idx:
|
| 67 |
label_mapping[label] = category_to_idx[label]
|
| 68 |
else:
|
|
|
|
| 69 |
available_indices = set(range(len(CATEGORIES))) - set(label_mapping.values())
|
| 70 |
if available_indices:
|
| 71 |
label_mapping[label] = min(available_indices)
|
|
|
|
| 74 |
|
| 75 |
df['label_idx'] = df[label_column].map(label_mapping)
|
| 76 |
else:
|
|
|
|
| 77 |
df['label_idx'] = df[label_column].astype(int)
|
| 78 |
|
|
|
|
| 79 |
if df['label_idx'].min() < 0 or df['label_idx'].max() >= len(CATEGORIES):
|
| 80 |
raise ValueError(f"Label indices must be between 0 and {len(CATEGORIES)-1}")
|
| 81 |
|
|
|
|
| 82 |
train_df, val_df = train_test_split(
|
| 83 |
df,
|
| 84 |
test_size=test_size,
|
|
|
|
| 86 |
stratify=df['label_idx']
|
| 87 |
)
|
| 88 |
|
|
|
|
| 89 |
train_dataset = Dataset.from_pandas(train_df[[text_column, 'label_idx']])
|
| 90 |
val_dataset = Dataset.from_pandas(val_df[[text_column, 'label_idx']])
|
| 91 |
|
|
|
|
| 105 |
if uploaded_file is None:
|
| 106 |
return "Please upload a dataset file first."
|
| 107 |
|
|
|
|
| 108 |
file_path = uploaded_file.name if hasattr(uploaded_file, 'name') else uploaded_file
|
| 109 |
|
| 110 |
df = pd.read_csv(file_path)
|
|
|
|
| 152 |
if not username or not model_name:
|
| 153 |
return None, "Please provide both username and model name"
|
| 154 |
|
|
|
|
| 155 |
model_name = model_name.strip().lower().replace(" ", "-")
|
| 156 |
model_name = ''.join(c for c in model_name if c.isalnum() or c in ['-', '_'])
|
| 157 |
|
|
|
|
| 158 |
hub_model_id = f"{username}/{model_name}"
|
| 159 |
|
| 160 |
return hub_model_id, None
|
|
|
|
| 164 |
global CURRENT_MODEL, CURRENT_TOKENIZER
|
| 165 |
|
| 166 |
try:
|
|
|
|
| 167 |
if os.path.exists(model_path):
|
| 168 |
CURRENT_TOKENIZER = AutoTokenizer.from_pretrained(model_path)
|
| 169 |
CURRENT_MODEL = BertForSequenceClassification.from_pretrained(
|
|
|
|
| 172 |
)
|
| 173 |
return f"✅ Model loaded from local path: {model_path}"
|
| 174 |
|
|
|
|
| 175 |
try:
|
| 176 |
CURRENT_TOKENIZER = AutoTokenizer.from_pretrained(model_path)
|
| 177 |
CURRENT_MODEL = BertForSequenceClassification.from_pretrained(
|
|
|
|
| 180 |
)
|
| 181 |
return f"✅ Model loaded from Hugging Face Hub: {model_path}"
|
| 182 |
except Exception as hub_error:
|
|
|
|
| 183 |
CURRENT_TOKENIZER = AutoTokenizer.from_pretrained("bert-base-uncased")
|
| 184 |
CURRENT_MODEL = BertForSequenceClassification.from_pretrained(
|
| 185 |
"bert-base-uncased",
|
|
|
|
| 227 |
TRAINING_LOGS.append(login_result)
|
| 228 |
yield "\n".join(TRAINING_LOGS)
|
| 229 |
|
|
|
|
| 230 |
if push_to_hub:
|
| 231 |
hub_model_id, error = validate_hub_model_id(username, model_name)
|
| 232 |
if error:
|
|
|
|
| 236 |
else:
|
| 237 |
hub_model_id = None
|
| 238 |
|
|
|
|
| 239 |
if uploaded_file is None:
|
| 240 |
TRAINING_LOGS.append("❌ Please upload a dataset file")
|
| 241 |
yield "\n".join(TRAINING_LOGS)
|
|
|
|
| 244 |
dataset_file = uploaded_file.name if hasattr(uploaded_file, 'name') else uploaded_file
|
| 245 |
|
| 246 |
try:
|
|
|
|
| 247 |
TRAINING_LOGS.append(f"π Loading dataset from uploaded file...")
|
| 248 |
yield "\n".join(TRAINING_LOGS)
|
| 249 |
|
|
|
|
| 256 |
TRAINING_LOGS.append(f"- Validation samples: {len(dataset_dict['validation'])}")
|
| 257 |
yield "\n".join(TRAINING_LOGS)
|
| 258 |
|
|
|
|
| 259 |
TRAINING_LOGS.append("π€ Loading BERT model and tokenizer...")
|
| 260 |
yield "\n".join(TRAINING_LOGS)
|
| 261 |
|
|
|
|
| 268 |
TRAINING_LOGS.append("✅ Model and tokenizer loaded")
|
| 269 |
yield "\n".join(TRAINING_LOGS)
|
| 270 |
|
|
|
|
| 271 |
TRAINING_LOGS.append("π€ Tokenizing datasets...")
|
| 272 |
yield "\n".join(TRAINING_LOGS)
|
| 273 |
|
| 274 |
def tokenize_batch(examples):
|
| 275 |
return tokenize_function(examples, tokenizer, final_text_col, 512)
|
| 276 |
|
|
|
|
| 277 |
columns_to_remove = [col for col in dataset_dict['train'].column_names if col != final_label_col]
|
| 278 |
|
| 279 |
tokenized_datasets = dataset_dict.map(
|
|
|
|
| 282 |
remove_columns=columns_to_remove
|
| 283 |
)
|
| 284 |
|
|
|
|
| 285 |
tokenized_datasets = tokenized_datasets.rename_column(final_label_col, 'labels')
|
| 286 |
|
| 287 |
TRAINING_LOGS.append("✅ Tokenization completed")
|
| 288 |
yield "\n".join(TRAINING_LOGS)
|
| 289 |
|
|
|
|
| 290 |
output_dir = Path(MODEL_PATH)
|
| 291 |
output_dir.mkdir(parents=True, exist_ok=True)
|
| 292 |
|
|
|
|
| 293 |
total_steps = len(tokenized_datasets['train']) // batch_size * num_epochs
|
| 294 |
eval_steps = max(10, min(100, total_steps // 4))
|
| 295 |
save_steps = max(20, min(500, total_steps // 2))
|
|
|
|
| 302 |
TRAINING_LOGS.append(f"- Warmup steps: {warmup_steps}")
|
| 303 |
yield "\n".join(TRAINING_LOGS)
|
| 304 |
|
|
|
|
| 305 |
training_args = TrainingArguments(
|
| 306 |
output_dir=str(output_dir),
|
| 307 |
num_train_epochs=num_epochs,
|
|
|
|
| 328 |
remove_unused_columns=False,
|
| 329 |
)
|
| 330 |
|
|
|
|
| 331 |
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
|
| 332 |
|
|
|
|
| 333 |
trainer = Trainer(
|
| 334 |
model=model,
|
| 335 |
args=training_args,
|
|
|
|
| 344 |
TRAINING_LOGS.append("π Starting training...")
|
| 345 |
yield "\n".join(TRAINING_LOGS)
|
| 346 |
|
|
|
|
| 347 |
class ProgressCallback:
|
| 348 |
def __init__(self, logs_list):
|
| 349 |
self.logs = logs_list
|
|
|
|
| 367 |
progress_callback = ProgressCallback(TRAINING_LOGS)
|
| 368 |
trainer.add_callback(progress_callback)
|
| 369 |
|
|
|
|
| 370 |
try:
|
| 371 |
trainer.train()
|
| 372 |
TRAINING_LOGS.append("✅ Training completed successfully!")
|
|
|
|
| 376 |
yield "\n".join(TRAINING_LOGS)
|
| 377 |
return
|
| 378 |
|
|
|
|
| 379 |
TRAINING_LOGS.append("💾 Saving model...")
|
| 380 |
yield "\n".join(TRAINING_LOGS)
|
| 381 |
|
| 382 |
trainer.save_model()
|
| 383 |
tokenizer.save_pretrained(output_dir)
|
| 384 |
|
|
|
|
| 385 |
CURRENT_MODEL = model
|
| 386 |
CURRENT_TOKENIZER = tokenizer
|
| 387 |
|
| 388 |
TRAINING_LOGS.append("✅ Model saved successfully!")
|
| 389 |
yield "\n".join(TRAINING_LOGS)
|
| 390 |
|
|
|
|
| 391 |
TRAINING_LOGS.append("π Running final evaluation...")
|
| 392 |
yield "\n".join(TRAINING_LOGS)
|
| 393 |
|
|
|
|
| 400 |
else:
|
| 401 |
TRAINING_LOGS.append(f" {key}: {value}")
|
| 402 |
|
|
|
|
| 403 |
with open(output_dir / "eval_results.json", "w") as f:
|
| 404 |
json.dump(eval_results, f, indent=2)
|
| 405 |
|
|
|
|
| 408 |
|
| 409 |
yield "\n".join(TRAINING_LOGS)
|
| 410 |
|
|
|
|
| 411 |
if push_to_hub and hub_model_id:
|
| 412 |
TRAINING_LOGS.append(f"π€ Pushing to Hugging Face Hub: {hub_model_id}")
|
| 413 |
yield "\n".join(TRAINING_LOGS)
|
|
|
|
| 431 |
"""Make a prediction on a single text input"""
|
| 432 |
global CURRENT_MODEL, CURRENT_TOKENIZER
|
| 433 |
|
|
|
|
| 434 |
if CURRENT_MODEL is None or model_path != MODEL_PATH:
|
| 435 |
load_result = load_model(model_path)
|
| 436 |
if load_result.startswith("❌"):
|
|
|
|
| 440 |
if not text.strip():
|
| 441 |
return "Please enter some text to classify."
|
| 442 |
|
|
|
|
| 443 |
original_tokens = CURRENT_TOKENIZER(text, truncation=False)
|
| 444 |
was_truncated = len(original_tokens['input_ids']) > 512
|
| 445 |
|
|
|
|
| 446 |
inputs = CURRENT_TOKENIZER(text, return_tensors="pt", truncation=True, max_length=512)
|
| 447 |
|
|
|
|
| 448 |
with torch.no_grad():
|
| 449 |
outputs = CURRENT_MODEL(**inputs)
|
| 450 |
predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
|
| 451 |
predicted_class_id = predictions.argmax().item()
|
| 452 |
confidence = predictions.max().item()
|
| 453 |
|
|
|
|
| 454 |
predicted_category = idx_to_category[predicted_class_id]
|
| 455 |
|
|
|
|
| 456 |
truncation_warning = "\n\n⚠️ Note: This complaint was truncated to fit BERT's 512 token limit." if was_truncated else ""
|
| 457 |
|
| 458 |
result = []
|
|
|
|
| 476 |
"""Make predictions on a CSV file with complaints"""
|
| 477 |
global CURRENT_MODEL, CURRENT_TOKENIZER
|
| 478 |
|
|
|
|
| 479 |
if CURRENT_MODEL is None or model_path != MODEL_PATH:
|
| 480 |
load_result = load_model(model_path)
|
| 481 |
if load_result.startswith("❌"):
|
| 482 |
+
return load_result, None
|
| 483 |
|
| 484 |
try:
|
|
|
|
| 485 |
if hasattr(csv_file, 'name'):
|
| 486 |
df = pd.read_csv(csv_file.name)
|
| 487 |
else:
|
| 488 |
df = pd.read_csv(csv_file)
|
| 489 |
|
| 490 |
if 'complaint' not in df.columns:
|
| 491 |
+
return "❌ CSV file must have a 'complaint' column", None
|
| 492 |
|
| 493 |
results = []
|
| 494 |
predictions_list = []
|
|
|
|
| 497 |
for i, row in enumerate(df.iterrows()):
|
| 498 |
complaint = str(row[1]['complaint'])
|
| 499 |
|
|
|
|
| 500 |
original_tokens = CURRENT_TOKENIZER(complaint, truncation=False)
|
| 501 |
was_truncated = len(original_tokens['input_ids']) > 512
|
| 502 |
if was_truncated:
|
| 503 |
truncated_count += 1
|
| 504 |
|
|
|
|
| 505 |
inputs = CURRENT_TOKENIZER(complaint, return_tensors="pt", truncation=True, max_length=512)
|
| 506 |
with torch.no_grad():
|
| 507 |
outputs = CURRENT_MODEL(**inputs)
|
|
|
|
| 529 |
if truncated_count > 0:
|
| 530 |
results.append(f"\n⚠️ {truncated_count} complaints were truncated to fit BERT's 512 token limit.")
|
| 531 |
|
|
|
|
| 532 |
results_df = pd.DataFrame(predictions_list)
|
| 533 |
results_file = "prediction_results.csv"
|
| 534 |
results_df.to_csv(results_file, index=False)
|
| 535 |
results.append(f"\n💾 Full results saved to {results_file}")
|
| 536 |
|
| 537 |
+
return "\n".join(results), results_file
|
| 538 |
|
| 539 |
except Exception as e:
|
| 540 |
+
return f"❌ CSV processing failed: {str(e)}", None
|
| 541 |
|
| 542 |
def push_to_hub_after_training(model_path, username, model_name, token):
|
| 543 |
"""Push a trained model to Hugging Face Hub"""
|
|
|
|
| 549 |
if error:
|
| 550 |
return f"❌ {error}"
|
| 551 |
|
|
|
|
| 552 |
login(token)
|
| 553 |
if not os.path.exists(model_path):
|
| 554 |
return "❌ No trained model found. Please train a model first."
|
|
|
|
| 559 |
except Exception as e:
|
| 560 |
return f"❌ Failed to load model: {str(e)}"
|
| 561 |
|
|
|
|
| 562 |
try:
|
| 563 |
model.push_to_hub(hub_model_id)
|
| 564 |
tokenizer.push_to_hub(hub_model_id)
|
|
|
|
| 603 |
else:
|
| 604 |
return "No CSV files found in the current directory."
|
| 605 |
|
| 606 |
+
# --- Gradio UI Definition (Correctly structured) ---
|
| 607 |
+
|
| 608 |
# Initialize tokenizer on startup
|
| 609 |
if CURRENT_TOKENIZER is None:
|
| 610 |
try:
|
|
|
|
| 616 |
print("π Launching BERT Complaint Classifier...")
|
| 617 |
print("π Available at: http://localhost:7860")
|
| 618 |
|
| 619 |
+
# The entire Gradio UI definition must be within this single block
|
| 620 |
with gr.Blocks(title="BERT Complaint Classifier", theme=gr.themes.Soft()) as app:
|
| 621 |
gr.Markdown("# BERT Complaint Classifier π£οΈπ€")
|
| 622 |
gr.Markdown("Fine-tune a BERT model or use an existing one to classify customer complaints.")
|
|
|
|
| 679 |
predict_btn = gr.Button("Classify Complaint", variant="primary")
|
| 680 |
single_prediction_output = gr.Markdown("Prediction will appear here...")
|
| 681 |
|
|
|
|
| 682 |
text_input.change(count_tokens, inputs=text_input, outputs=token_count_output)
|
| 683 |
|
| 684 |
with gr.Tab("Predict from CSV"):
|
|
|
|
| 694 |
csv_prediction_output = gr.Markdown("Predictions will appear here...")
|
| 695 |
download_link = gr.File(label="Download Full Predictions", interactive=False)
|
| 696 |
|
|
|
|
| 697 |
predict_btn.click(
|
| 698 |
predict_text,
|
| 699 |
inputs=[text_input, model_path_input],
|
|
|
|
| 750 |
hub_token_input_push = gr.Textbox(label="Hugging Face Token", type="password")
|
| 751 |
|
| 752 |
push_btn = gr.Button("π Push Model to Hub", variant="primary")
|
| 753 |
+
push_output = gr.Textbox(label="Results", lines=3, interactive=False)
|
| 754 |
|
|
|
|
| 755 |
push_btn.click(
|
| 756 |
push_to_hub_after_training,
|
| 757 |
inputs=[gr.Textbox(value=MODEL_PATH, visible=False), hub_username_input_push, hub_model_name_input_push, hub_token_input_push],
|
| 758 |
outputs=push_output
|
| 759 |
)
|
| 760 |
|
|
|
|
| 761 |
preview_btn.click(
|
| 762 |
preview_dataset,
|
| 763 |
inputs=[uploaded_file, text_column_input, label_column_input],
|
|
|
|
| 786 |
outputs=available_datasets
|
| 787 |
)
|
| 788 |
|
|
|
|
| 789 |
app.load(display_available_datasets, outputs=available_datasets)
|
| 790 |
|
|
|
|
| 791 |
if __name__ == "__main__":
|
| 792 |
app.launch(
|
| 793 |
server_name="0.0.0.0",
|