Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -2,13 +2,28 @@ import gradio as gr
|
|
| 2 |
import torch
|
| 3 |
import pandas as pd
|
| 4 |
import os
|
| 5 |
-
import tempfile
|
| 6 |
-
import time
|
| 7 |
-
import subprocess
|
| 8 |
import json
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
from huggingface_hub import login, HfApi
|
| 10 |
-
from transformers import
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
# Global variables
|
| 14 |
MODEL_PATH = "local-model"
|
|
@@ -20,38 +35,6 @@ TRAINING_LOGS = []
|
|
| 20 |
CURRENT_MODEL = None
|
| 21 |
CURRENT_TOKENIZER = None
|
| 22 |
|
| 23 |
-
# Local data files
|
| 24 |
-
LOCAL_DATA_FILES = [
|
| 25 |
-
"merged-test-data.csv",
|
| 26 |
-
"test-category.csv",
|
| 27 |
-
"test-complaint.csv"
|
| 28 |
-
]
|
| 29 |
-
|
| 30 |
-
def get_available_datasets():
|
| 31 |
-
"""Get list of available local datasets"""
|
| 32 |
-
available_files = []
|
| 33 |
-
for file in LOCAL_DATA_FILES:
|
| 34 |
-
if os.path.exists(file):
|
| 35 |
-
try:
|
| 36 |
-
df = pd.read_csv(file)
|
| 37 |
-
available_files.append(f"{file} ({len(df)} rows)")
|
| 38 |
-
except Exception as e:
|
| 39 |
-
available_files.append(f"{file} (Error: {str(e)})")
|
| 40 |
-
else:
|
| 41 |
-
available_files.append(f"{file} (Not found)")
|
| 42 |
-
|
| 43 |
-
# Also check for any other CSV files in the directory
|
| 44 |
-
for file in os.listdir("."):
|
| 45 |
-
if file.endswith(".csv") and file not in LOCAL_DATA_FILES:
|
| 46 |
-
if os.path.exists(file):
|
| 47 |
-
try:
|
| 48 |
-
df = pd.read_csv(file)
|
| 49 |
-
available_files.append(f"{file} ({len(df)} rows)")
|
| 50 |
-
except:
|
| 51 |
-
available_files.append(f"{file} (Error reading)")
|
| 52 |
-
|
| 53 |
-
return available_files
|
| 54 |
-
|
| 55 |
def load_and_prepare_local_dataset(file_path, text_column, label_column, test_size=0.2):
|
| 56 |
"""Load and prepare local CSV dataset for training"""
|
| 57 |
try:
|
|
@@ -104,8 +87,6 @@ def load_and_prepare_local_dataset(file_path, text_column, label_column, test_si
|
|
| 104 |
raise ValueError(f"Label indices must be between 0 and {len(CATEGORIES)-1}")
|
| 105 |
|
| 106 |
# Create train/validation split
|
| 107 |
-
from sklearn.model_selection import train_test_split
|
| 108 |
-
|
| 109 |
train_df, val_df = train_test_split(
|
| 110 |
df,
|
| 111 |
test_size=test_size,
|
|
@@ -224,6 +205,262 @@ def load_model(model_path):
|
|
| 224 |
except Exception as e:
|
| 225 |
return f"❌ Failed to load model: {str(e)}"
|
| 226 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
def predict_text(text, model_path):
|
| 228 |
"""Make a prediction on a single text input"""
|
| 229 |
global CURRENT_MODEL, CURRENT_TOKENIZER
|
|
@@ -235,25 +472,45 @@ def predict_text(text, model_path):
|
|
| 235 |
return load_result
|
| 236 |
|
| 237 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
# Tokenize input
|
| 239 |
inputs = CURRENT_TOKENIZER(text, return_tensors="pt", truncation=True, max_length=512)
|
| 240 |
|
| 241 |
# Make prediction
|
| 242 |
with torch.no_grad():
|
| 243 |
outputs = CURRENT_MODEL(**inputs)
|
| 244 |
-
|
|
|
|
|
|
|
| 245 |
|
| 246 |
-
# Get category
|
| 247 |
-
predicted_category = idx_to_category[
|
| 248 |
|
| 249 |
-
#
|
| 250 |
-
original_tokens = CURRENT_TOKENIZER(text, truncation=False)
|
| 251 |
-
was_truncated = len(original_tokens['input_ids']) > 512
|
| 252 |
truncation_warning = "\n\n⚠️ Note: This complaint was truncated to fit BERT's 512 token limit." if was_truncated else ""
|
| 253 |
|
| 254 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
except Exception as e:
|
| 256 |
-
return f"❌ Prediction
|
| 257 |
|
| 258 |
def predict_csv(csv_file, model_path):
|
| 259 |
"""Make predictions on a CSV file with complaints"""
|
|
@@ -276,6 +533,7 @@ def predict_csv(csv_file, model_path):
|
|
| 276 |
return "❌ CSV file must have a 'complaint' column"
|
| 277 |
|
| 278 |
results = []
|
|
|
|
| 279 |
truncated_count = 0
|
| 280 |
|
| 281 |
for i, row in enumerate(df.iterrows()):
|
|
@@ -291,13 +549,22 @@ def predict_csv(csv_file, model_path):
|
|
| 291 |
inputs = CURRENT_TOKENIZER(complaint, return_tensors="pt", truncation=True, max_length=512)
|
| 292 |
with torch.no_grad():
|
| 293 |
outputs = CURRENT_MODEL(**inputs)
|
| 294 |
-
|
|
|
|
|
|
|
| 295 |
|
| 296 |
predicted_category = idx_to_category[predicted_idx]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
|
| 298 |
truncation_mark = " ⚠️" if was_truncated else ""
|
| 299 |
preview = complaint if len(complaint) <= 50 else complaint[:47] + "..."
|
| 300 |
-
results.append(f"Complaint {i+1}{truncation_mark}: {preview}
|
|
|
|
| 301 |
|
| 302 |
if i >= 19:
|
| 303 |
results.append(f"... and {len(df) - 20} more (showing first 20 out of {len(df)} complaints)")
|
|
@@ -306,141 +573,17 @@ def predict_csv(csv_file, model_path):
|
|
| 306 |
if truncated_count > 0:
|
| 307 |
results.append(f"\n⚠️ {truncated_count} complaints were truncated to fit BERT's 512 token limit.")
|
| 308 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 309 |
return "\n".join(results)
|
|
|
|
| 310 |
except Exception as e:
|
| 311 |
return f"❌ CSV processing failed: {str(e)}"
|
| 312 |
|
| 313 |
-
def train_model(uploaded_file, text_column, label_column, num_epochs, batch_size,
|
| 314 |
-
learning_rate, hf_token, push_to_hub, username, model_name):
|
| 315 |
-
"""Start the model training process with local data"""
|
| 316 |
-
global TRAINING_LOGS, MODEL_PATH
|
| 317 |
-
|
| 318 |
-
TRAINING_LOGS = [] # Reset logs at the start of training
|
| 319 |
-
|
| 320 |
-
if hf_token:
|
| 321 |
-
login_result = login_to_hf(hf_token)
|
| 322 |
-
TRAINING_LOGS.append(login_result)
|
| 323 |
-
yield "\n".join(TRAINING_LOGS)
|
| 324 |
-
|
| 325 |
-
# Validate hub model ID if pushing to hub
|
| 326 |
-
if push_to_hub:
|
| 327 |
-
hub_model_id, error = validate_hub_model_id(username, model_name)
|
| 328 |
-
if error:
|
| 329 |
-
TRAINING_LOGS.append(f"❌ {error}")
|
| 330 |
-
yield "\n".join(TRAINING_LOGS)
|
| 331 |
-
return
|
| 332 |
-
else:
|
| 333 |
-
hub_model_id = None
|
| 334 |
-
|
| 335 |
-
# Validate uploaded file
|
| 336 |
-
if uploaded_file is None:
|
| 337 |
-
TRAINING_LOGS.append("❌ Please upload a dataset file")
|
| 338 |
-
yield "\n".join(TRAINING_LOGS)
|
| 339 |
-
return
|
| 340 |
-
|
| 341 |
-
# Get the file path from the uploaded file
|
| 342 |
-
dataset_file = uploaded_file.name if hasattr(uploaded_file, 'name') else uploaded_file
|
| 343 |
-
|
| 344 |
-
try:
|
| 345 |
-
# Load and prepare the dataset
|
| 346 |
-
TRAINING_LOGS.append(f"📊 Loading dataset from uploaded file...")
|
| 347 |
-
yield "\n".join(TRAINING_LOGS)
|
| 348 |
-
|
| 349 |
-
dataset_dict, final_text_col, final_label_col = load_and_prepare_local_dataset(
|
| 350 |
-
dataset_file, text_column, label_column
|
| 351 |
-
)
|
| 352 |
-
|
| 353 |
-
TRAINING_LOGS.append(f"✅ Dataset loaded successfully!")
|
| 354 |
-
TRAINING_LOGS.append(f"- Train samples: {len(dataset_dict['train'])}")
|
| 355 |
-
TRAINING_LOGS.append(f"- Validation samples: {len(dataset_dict['validation'])}")
|
| 356 |
-
yield "\n".join(TRAINING_LOGS)
|
| 357 |
-
|
| 358 |
-
# Save dataset temporarily for the training script
|
| 359 |
-
temp_dataset_path = "temp_dataset"
|
| 360 |
-
os.makedirs(temp_dataset_path, exist_ok=True)
|
| 361 |
-
dataset_dict.save_to_disk(temp_dataset_path)
|
| 362 |
-
|
| 363 |
-
TRAINING_LOGS.append("💾 Dataset prepared for training...")
|
| 364 |
-
yield "\n".join(TRAINING_LOGS)
|
| 365 |
-
|
| 366 |
-
except Exception as e:
|
| 367 |
-
TRAINING_LOGS.append(f"❌ Error preparing dataset: {str(e)}")
|
| 368 |
-
yield "\n".join(TRAINING_LOGS)
|
| 369 |
-
return
|
| 370 |
-
|
| 371 |
-
# Create training command for local dataset
|
| 372 |
-
cmd = [
|
| 373 |
-
"python", "bert_finetune.py",
|
| 374 |
-
"--dataset_path", temp_dataset_path, # Use local path instead of HF dataset name
|
| 375 |
-
"--model_id", "bert-base-uncased",
|
| 376 |
-
"--output_dir", MODEL_PATH,
|
| 377 |
-
"--feature_column", final_text_col,
|
| 378 |
-
"--label_column", final_label_col,
|
| 379 |
-
"--num_labels", "3",
|
| 380 |
-
"--num_train_epochs", str(num_epochs),
|
| 381 |
-
"--batch_size", str(batch_size),
|
| 382 |
-
"--learning_rate", str(learning_rate),
|
| 383 |
-
"--max_length", "512"
|
| 384 |
-
]
|
| 385 |
-
|
| 386 |
-
if push_to_hub and hub_model_id:
|
| 387 |
-
cmd.extend(["--push_to_hub", "--hub_model_id", hub_model_id])
|
| 388 |
-
if hf_token:
|
| 389 |
-
cmd.extend(["--hf_token", hf_token])
|
| 390 |
-
|
| 391 |
-
TRAINING_LOGS.append(f"🚀 Starting training with command: {' '.join(cmd)}")
|
| 392 |
-
yield "\n".join(TRAINING_LOGS)
|
| 393 |
-
|
| 394 |
-
try:
|
| 395 |
-
process = subprocess.Popen(
|
| 396 |
-
cmd,
|
| 397 |
-
stdout=subprocess.PIPE,
|
| 398 |
-
stderr=subprocess.STDOUT,
|
| 399 |
-
universal_newlines=True,
|
| 400 |
-
bufsize=1
|
| 401 |
-
)
|
| 402 |
-
|
| 403 |
-
TRAINING_LOGS.append("🔄 Training started...")
|
| 404 |
-
yield "\n".join(TRAINING_LOGS)
|
| 405 |
-
|
| 406 |
-
while True:
|
| 407 |
-
line = process.stdout.readline()
|
| 408 |
-
if not line and process.poll() is not None:
|
| 409 |
-
break
|
| 410 |
-
if line:
|
| 411 |
-
TRAINING_LOGS.append(line.strip())
|
| 412 |
-
yield "\n".join(TRAINING_LOGS)
|
| 413 |
-
|
| 414 |
-
process.wait()
|
| 415 |
-
|
| 416 |
-
if process.returncode == 0:
|
| 417 |
-
TRAINING_LOGS.append("✅ Training completed successfully!")
|
| 418 |
-
if push_to_hub and hub_model_id:
|
| 419 |
-
TRAINING_LOGS.append(f"🤗 Model pushed to Hugging Face Hub: {hub_model_id}")
|
| 420 |
-
|
| 421 |
-
# Load the trained model
|
| 422 |
-
TRAINING_LOGS.append("📥 Loading trained model...")
|
| 423 |
-
load_result = load_model(MODEL_PATH)
|
| 424 |
-
TRAINING_LOGS.append(load_result)
|
| 425 |
-
|
| 426 |
-
# Clean up temporary files
|
| 427 |
-
import shutil
|
| 428 |
-
try:
|
| 429 |
-
shutil.rmtree(temp_dataset_path)
|
| 430 |
-
TRAINING_LOGS.append("🧹 Cleaned up temporary files")
|
| 431 |
-
except:
|
| 432 |
-
pass
|
| 433 |
-
|
| 434 |
-
# Final success message
|
| 435 |
-
TRAINING_LOGS.append("\n✨ All done! Your model is ready to use.")
|
| 436 |
-
else:
|
| 437 |
-
TRAINING_LOGS.append(f"❌ Training failed with return code {process.returncode}")
|
| 438 |
-
|
| 439 |
-
except Exception as e:
|
| 440 |
-
TRAINING_LOGS.append(f"❌ Error during training: {str(e)}")
|
| 441 |
-
|
| 442 |
-
yield "\n".join(TRAINING_LOGS)
|
| 443 |
-
|
| 444 |
def push_to_hub_after_training(model_path, username, model_name, token):
|
| 445 |
"""Push a trained model to Hugging Face Hub"""
|
| 446 |
try:
|
|
@@ -473,220 +616,96 @@ def push_to_hub_after_training(model_path, username, model_name, token):
|
|
| 473 |
except Exception as e:
|
| 474 |
return f"❌ Error: {str(e)}"
|
| 475 |
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 480 |
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
file_types=[".csv"],
|
| 492 |
-
type="filepath"
|
| 493 |
-
)
|
| 494 |
-
|
| 495 |
-
# Column configuration
|
| 496 |
-
with gr.Row():
|
| 497 |
-
text_column = gr.Textbox(
|
| 498 |
-
label="Text Column Name",
|
| 499 |
-
value="complaint",
|
| 500 |
-
placeholder="e.g., complaint, text, description"
|
| 501 |
-
)
|
| 502 |
-
label_column = gr.Textbox(
|
| 503 |
-
label="Label Column Name",
|
| 504 |
-
value="category",
|
| 505 |
-
placeholder="e.g., category, label, class"
|
| 506 |
-
)
|
| 507 |
-
|
| 508 |
-
# Dataset preview
|
| 509 |
-
preview_btn = gr.Button("📊 Preview Dataset", variant="secondary")
|
| 510 |
-
dataset_preview = gr.Markdown("Upload a dataset file and click 'Preview Dataset' to see its structure.")
|
| 511 |
-
|
| 512 |
-
# Training parameters
|
| 513 |
-
with gr.Row():
|
| 514 |
-
num_epochs = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Number of Epochs")
|
| 515 |
-
batch_size = gr.Slider(minimum=4, maximum=32, value=8, step=4, label="Batch Size")
|
| 516 |
-
learning_rate = gr.Slider(minimum=1e-5, maximum=5e-5, value=2e-5, step=1e-5, label="Learning Rate")
|
| 517 |
-
|
| 518 |
-
with gr.Accordion("Hugging Face Hub Settings", open=False):
|
| 519 |
-
hf_token = gr.Textbox(
|
| 520 |
-
label="Hugging Face Token (required for pushing to Hub)",
|
| 521 |
-
type="password"
|
| 522 |
-
)
|
| 523 |
-
|
| 524 |
-
gr.Markdown("""### Choose when to push to Hub:
|
| 525 |
-
1. During Training: Model will be pushed automatically when training completes
|
| 526 |
-
2. After Training: You can push the trained model manually later""")
|
| 527 |
-
|
| 528 |
-
# During Training Push
|
| 529 |
-
with gr.Group():
|
| 530 |
-
push_to_hub = gr.Checkbox(
|
| 531 |
-
label="Push Model to Hub during training",
|
| 532 |
-
value=False
|
| 533 |
-
)
|
| 534 |
-
|
| 535 |
-
with gr.Column(visible=False) as hub_settings:
|
| 536 |
-
username = gr.Textbox(
|
| 537 |
-
label="Hugging Face Username",
|
| 538 |
-
placeholder="e.g., huggingface-username"
|
| 539 |
-
)
|
| 540 |
-
model_name = gr.Textbox(
|
| 541 |
-
label="Model Name",
|
| 542 |
-
placeholder="e.g., bert-complaint-classifier"
|
| 543 |
-
)
|
| 544 |
-
|
| 545 |
-
# Post-Training Push
|
| 546 |
-
with gr.Group():
|
| 547 |
-
post_train_push = gr.Checkbox(
|
| 548 |
-
label="Push trained model to Hub after training",
|
| 549 |
-
value=False
|
| 550 |
-
)
|
| 551 |
-
|
| 552 |
-
with gr.Column(visible=False) as post_train_settings:
|
| 553 |
-
post_train_username = gr.Textbox(
|
| 554 |
-
label="Hugging Face Username",
|
| 555 |
-
placeholder="e.g., huggingface-username"
|
| 556 |
-
)
|
| 557 |
-
post_train_model_name = gr.Textbox(
|
| 558 |
-
label="Model Name",
|
| 559 |
-
placeholder="e.g., bert-complaint-classifier"
|
| 560 |
-
)
|
| 561 |
-
post_train_token = gr.Textbox(
|
| 562 |
-
label="Hugging Face Token (if different from above)",
|
| 563 |
-
type="password"
|
| 564 |
-
)
|
| 565 |
-
post_train_push_btn = gr.Button(
|
| 566 |
-
"Push Model to Hub",
|
| 567 |
-
variant="secondary"
|
| 568 |
-
)
|
| 569 |
-
post_train_status = gr.Textbox(label="Upload Status")
|
| 570 |
-
|
| 571 |
-
# Show/hide settings based on checkboxes
|
| 572 |
-
push_to_hub.change(
|
| 573 |
-
lambda x: gr.update(visible=x),
|
| 574 |
-
inputs=push_to_hub,
|
| 575 |
-
outputs=hub_settings
|
| 576 |
-
)
|
| 577 |
-
|
| 578 |
-
post_train_push.change(
|
| 579 |
-
lambda x: gr.update(visible=x),
|
| 580 |
-
inputs=post_train_push,
|
| 581 |
-
outputs=post_train_settings
|
| 582 |
-
)
|
| 583 |
|
| 584 |
-
|
| 585 |
-
|
|
|
|
|
|
|
|
|
|
| 586 |
|
| 587 |
-
|
| 588 |
-
|
| 589 |
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
|
| 594 |
-
|
| 595 |
-
)
|
| 596 |
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
inputs=[
|
| 601 |
-
dataset_file,
|
| 602 |
-
text_column,
|
| 603 |
-
label_column,
|
| 604 |
-
num_epochs,
|
| 605 |
-
batch_size,
|
| 606 |
-
learning_rate,
|
| 607 |
-
hf_token,
|
| 608 |
-
push_to_hub,
|
| 609 |
-
username,
|
| 610 |
-
model_name
|
| 611 |
-
],
|
| 612 |
-
outputs=training_output,
|
| 613 |
-
show_progress="full"
|
| 614 |
-
)
|
| 615 |
|
| 616 |
-
|
| 617 |
-
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
|
| 621 |
-
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
],
|
| 625 |
-
outputs=post_train_status
|
| 626 |
-
)
|
| 627 |
-
|
| 628 |
-
# Classification Tab
|
| 629 |
-
with gr.TabItem("Classify Complaints"):
|
| 630 |
-
gr.Markdown("### Classify Customer Complaints")
|
| 631 |
|
| 632 |
-
|
| 633 |
-
|
| 634 |
-
|
| 635 |
-
placeholder="e.g., local-model or your-username/bert-complaint-classifier"
|
| 636 |
-
)
|
| 637 |
|
| 638 |
-
|
| 639 |
-
|
| 640 |
-
|
| 641 |
-
|
| 642 |
-
|
| 643 |
-
|
| 644 |
-
|
| 645 |
-
)
|
| 646 |
-
|
| 647 |
-
classify_btn = gr.Button("Classify", variant="primary")
|
| 648 |
-
token_info = gr.Markdown("Note: BERT has a 512 token limit. Longer complaints will be truncated.")
|
| 649 |
-
text_output = gr.Textbox(label="Classification Result", lines=5)
|
| 650 |
-
|
| 651 |
-
# Token counter
|
| 652 |
-
def count_tokens(text):
|
| 653 |
-
if not text or CURRENT_TOKENIZER is None:
|
| 654 |
-
return "Enter text to see token count"
|
| 655 |
-
tokens = CURRENT_TOKENIZER(text, truncation=False)
|
| 656 |
-
count = len(tokens['input_ids'])
|
| 657 |
-
if count > 512:
|
| 658 |
-
return f"⚠️ **Token count: {count}/512** - Text will be truncated for BERT"
|
| 659 |
-
else:
|
| 660 |
-
return f"Token count: {count}/512"
|
| 661 |
-
|
| 662 |
-
text_input.change(
|
| 663 |
-
fn=count_tokens,
|
| 664 |
-
inputs=text_input,
|
| 665 |
-
outputs=token_info
|
| 666 |
-
)
|
| 667 |
-
|
| 668 |
-
classify_btn.click(
|
| 669 |
-
predict_text,
|
| 670 |
-
inputs=[text_input, model_path],
|
| 671 |
-
outputs=text_output
|
| 672 |
-
)
|
| 673 |
-
|
| 674 |
-
# Batch Processing
|
| 675 |
-
with gr.TabItem("Batch Processing"):
|
| 676 |
-
gr.Markdown("Upload a CSV file with a 'complaint' column")
|
| 677 |
-
csv_input = gr.File(label="Upload CSV", file_types=[".csv"])
|
| 678 |
-
batch_classify_btn = gr.Button("Classify All", variant="primary")
|
| 679 |
-
csv_output = gr.Textbox(label="Classification Results", lines=15)
|
| 680 |
-
|
| 681 |
-
batch_classify_btn.click(
|
| 682 |
-
predict_csv,
|
| 683 |
-
inputs=[csv_input, model_path],
|
| 684 |
-
outputs=csv_output
|
| 685 |
-
)
|
| 686 |
|
| 687 |
# Launch the app
|
| 688 |
if __name__ == "__main__":
|
| 689 |
# Initialize tokenizer on startup
|
| 690 |
if CURRENT_TOKENIZER is None:
|
| 691 |
-
|
| 692 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import torch
|
| 3 |
import pandas as pd
|
| 4 |
import os
|
|
|
|
|
|
|
|
|
|
| 5 |
import json
|
| 6 |
+
import logging
|
| 7 |
+
import numpy as np
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from sklearn.metrics import accuracy_score, classification_report
|
| 11 |
+
from sklearn.model_selection import train_test_split
|
| 12 |
+
|
| 13 |
from huggingface_hub import login, HfApi
|
| 14 |
+
from transformers import (
|
| 15 |
+
AutoTokenizer,
|
| 16 |
+
BertForSequenceClassification,
|
| 17 |
+
TrainingArguments,
|
| 18 |
+
Trainer,
|
| 19 |
+
DataCollatorWithPadding,
|
| 20 |
+
EarlyStoppingCallback
|
| 21 |
+
)
|
| 22 |
+
from datasets import Dataset, DatasetDict
|
| 23 |
+
|
| 24 |
+
# Set up logging
|
| 25 |
+
logging.basicConfig(level=logging.INFO)
|
| 26 |
+
logger = logging.getLogger(__name__)
|
| 27 |
|
| 28 |
# Global variables
|
| 29 |
MODEL_PATH = "local-model"
|
|
|
|
| 35 |
CURRENT_MODEL = None
|
| 36 |
CURRENT_TOKENIZER = None
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
def load_and_prepare_local_dataset(file_path, text_column, label_column, test_size=0.2):
|
| 39 |
"""Load and prepare local CSV dataset for training"""
|
| 40 |
try:
|
|
|
|
| 87 |
raise ValueError(f"Label indices must be between 0 and {len(CATEGORIES)-1}")
|
| 88 |
|
| 89 |
# Create train/validation split
|
|
|
|
|
|
|
| 90 |
train_df, val_df = train_test_split(
|
| 91 |
df,
|
| 92 |
test_size=test_size,
|
|
|
|
| 205 |
except Exception as e:
|
| 206 |
return f"❌ Failed to load model: {str(e)}"
|
| 207 |
|
| 208 |
+
def tokenize_function(examples, tokenizer, feature_column, max_length=512):
|
| 209 |
+
"""Tokenize the input text"""
|
| 210 |
+
return tokenizer(
|
| 211 |
+
examples[feature_column],
|
| 212 |
+
truncation=True,
|
| 213 |
+
padding=False,
|
| 214 |
+
max_length=max_length
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
def compute_metrics(eval_pred):
|
| 218 |
+
"""Compute metrics for evaluation"""
|
| 219 |
+
predictions, labels = eval_pred
|
| 220 |
+
predictions = np.argmax(predictions, axis=1)
|
| 221 |
+
|
| 222 |
+
accuracy = accuracy_score(labels, predictions)
|
| 223 |
+
report = classification_report(labels, predictions, output_dict=True, zero_division=0)
|
| 224 |
+
|
| 225 |
+
return {
|
| 226 |
+
'accuracy': accuracy,
|
| 227 |
+
'f1_macro': report['macro avg']['f1-score'],
|
| 228 |
+
'f1_weighted': report['weighted avg']['f1-score'],
|
| 229 |
+
'precision_macro': report['macro avg']['precision'],
|
| 230 |
+
'recall_macro': report['macro avg']['recall']
|
| 231 |
+
}
|
| 232 |
+
|
| 233 |
+
def train_model_inline(uploaded_file, text_column, label_column, num_epochs, batch_size,
|
| 234 |
+
learning_rate, hf_token, push_to_hub, username, model_name):
|
| 235 |
+
"""Train the model using inline training (no subprocess)"""
|
| 236 |
+
global TRAINING_LOGS, MODEL_PATH, CURRENT_MODEL, CURRENT_TOKENIZER
|
| 237 |
+
|
| 238 |
+
TRAINING_LOGS = []
|
| 239 |
+
|
| 240 |
+
if hf_token:
|
| 241 |
+
login_result = login_to_hf(hf_token)
|
| 242 |
+
TRAINING_LOGS.append(login_result)
|
| 243 |
+
yield "\n".join(TRAINING_LOGS)
|
| 244 |
+
|
| 245 |
+
# Validate hub model ID if pushing to hub
|
| 246 |
+
if push_to_hub:
|
| 247 |
+
hub_model_id, error = validate_hub_model_id(username, model_name)
|
| 248 |
+
if error:
|
| 249 |
+
TRAINING_LOGS.append(f"❌ {error}")
|
| 250 |
+
yield "\n".join(TRAINING_LOGS)
|
| 251 |
+
return
|
| 252 |
+
else:
|
| 253 |
+
hub_model_id = None
|
| 254 |
+
|
| 255 |
+
# Validate uploaded file
|
| 256 |
+
if uploaded_file is None:
|
| 257 |
+
TRAINING_LOGS.append("❌ Please upload a dataset file")
|
| 258 |
+
yield "\n".join(TRAINING_LOGS)
|
| 259 |
+
return
|
| 260 |
+
|
| 261 |
+
dataset_file = uploaded_file.name if hasattr(uploaded_file, 'name') else uploaded_file
|
| 262 |
+
|
| 263 |
+
try:
|
| 264 |
+
# Load and prepare dataset
|
| 265 |
+
TRAINING_LOGS.append(f"📊 Loading dataset from uploaded file...")
|
| 266 |
+
yield "\n".join(TRAINING_LOGS)
|
| 267 |
+
|
| 268 |
+
dataset_dict, final_text_col, final_label_col = load_and_prepare_local_dataset(
|
| 269 |
+
dataset_file, text_column, label_column
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
TRAINING_LOGS.append(f"✅ Dataset loaded successfully!")
|
| 273 |
+
TRAINING_LOGS.append(f"- Train samples: {len(dataset_dict['train'])}")
|
| 274 |
+
TRAINING_LOGS.append(f"- Validation samples: {len(dataset_dict['validation'])}")
|
| 275 |
+
yield "\n".join(TRAINING_LOGS)
|
| 276 |
+
|
| 277 |
+
# Load model and tokenizer
|
| 278 |
+
TRAINING_LOGS.append("🤖 Loading BERT model and tokenizer...")
|
| 279 |
+
yield "\n".join(TRAINING_LOGS)
|
| 280 |
+
|
| 281 |
+
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
| 282 |
+
model = BertForSequenceClassification.from_pretrained(
|
| 283 |
+
"bert-base-uncased",
|
| 284 |
+
num_labels=len(CATEGORIES)
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
TRAINING_LOGS.append("✅ Model and tokenizer loaded")
|
| 288 |
+
yield "\n".join(TRAINING_LOGS)
|
| 289 |
+
|
| 290 |
+
# Tokenize datasets
|
| 291 |
+
TRAINING_LOGS.append("🔤 Tokenizing datasets...")
|
| 292 |
+
yield "\n".join(TRAINING_LOGS)
|
| 293 |
+
|
| 294 |
+
def tokenize_batch(examples):
|
| 295 |
+
return tokenize_function(examples, tokenizer, final_text_col, 512)
|
| 296 |
+
|
| 297 |
+
# Get columns to remove (keep only label column and tokenized features)
|
| 298 |
+
columns_to_remove = [col for col in dataset_dict['train'].column_names if col != final_label_col]
|
| 299 |
+
|
| 300 |
+
tokenized_datasets = dataset_dict.map(
|
| 301 |
+
tokenize_batch,
|
| 302 |
+
batched=True,
|
| 303 |
+
remove_columns=columns_to_remove
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
# Rename label column to 'labels' (required by Trainer)
|
| 307 |
+
tokenized_datasets = tokenized_datasets.rename_column(final_label_col, 'labels')
|
| 308 |
+
|
| 309 |
+
TRAINING_LOGS.append("✅ Tokenization completed")
|
| 310 |
+
yield "\n".join(TRAINING_LOGS)
|
| 311 |
+
|
| 312 |
+
# Set up training
|
| 313 |
+
output_dir = Path(MODEL_PATH)
|
| 314 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 315 |
+
|
| 316 |
+
# Calculate steps
|
| 317 |
+
total_steps = len(tokenized_datasets['train']) // batch_size * num_epochs
|
| 318 |
+
eval_steps = max(10, min(100, total_steps // 4))
|
| 319 |
+
save_steps = max(20, min(500, total_steps // 2))
|
| 320 |
+
logging_steps = max(5, min(50, total_steps // 10))
|
| 321 |
+
warmup_steps = min(500, total_steps // 10)
|
| 322 |
+
|
| 323 |
+
TRAINING_LOGS.append(f"📈 Training configuration:")
|
| 324 |
+
TRAINING_LOGS.append(f"- Total steps: {total_steps}")
|
| 325 |
+
TRAINING_LOGS.append(f"- Eval steps: {eval_steps}")
|
| 326 |
+
TRAINING_LOGS.append(f"- Warmup steps: {warmup_steps}")
|
| 327 |
+
yield "\n".join(TRAINING_LOGS)
|
| 328 |
+
|
| 329 |
+
# Training arguments
|
| 330 |
+
training_args = TrainingArguments(
|
| 331 |
+
output_dir=str(output_dir),
|
| 332 |
+
num_train_epochs=num_epochs,
|
| 333 |
+
per_device_train_batch_size=batch_size,
|
| 334 |
+
per_device_eval_batch_size=batch_size,
|
| 335 |
+
warmup_steps=warmup_steps,
|
| 336 |
+
weight_decay=0.01,
|
| 337 |
+
learning_rate=learning_rate,
|
| 338 |
+
logging_dir=str(output_dir / "logs"),
|
| 339 |
+
logging_steps=logging_steps,
|
| 340 |
+
eval_strategy="steps",
|
| 341 |
+
eval_steps=eval_steps,
|
| 342 |
+
save_steps=save_steps,
|
| 343 |
+
save_total_limit=2,
|
| 344 |
+
load_best_model_at_end=True,
|
| 345 |
+
metric_for_best_model="eval_accuracy",
|
| 346 |
+
greater_is_better=True,
|
| 347 |
+
push_to_hub=push_to_hub,
|
| 348 |
+
hub_model_id=hub_model_id if push_to_hub else None,
|
| 349 |
+
report_to=None,
|
| 350 |
+
dataloader_num_workers=0,
|
| 351 |
+
fp16=torch.cuda.is_available(),
|
| 352 |
+
seed=42,
|
| 353 |
+
remove_unused_columns=False,
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
# Data collator
|
| 357 |
+
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
|
| 358 |
+
|
| 359 |
+
# Create trainer
|
| 360 |
+
trainer = Trainer(
|
| 361 |
+
model=model,
|
| 362 |
+
args=training_args,
|
| 363 |
+
train_dataset=tokenized_datasets['train'],
|
| 364 |
+
eval_dataset=tokenized_datasets['validation'],
|
| 365 |
+
tokenizer=tokenizer,
|
| 366 |
+
data_collator=data_collator,
|
| 367 |
+
compute_metrics=compute_metrics,
|
| 368 |
+
callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
|
| 369 |
+
)
|
| 370 |
+
|
| 371 |
+
TRAINING_LOGS.append("🚀 Starting training...")
|
| 372 |
+
yield "\n".join(TRAINING_LOGS)
|
| 373 |
+
|
| 374 |
+
# Custom training loop with progress updates
|
| 375 |
+
class ProgressCallback:
|
| 376 |
+
def __init__(self, logs_list):
|
| 377 |
+
self.logs = logs_list
|
| 378 |
+
self.step_count = 0
|
| 379 |
+
|
| 380 |
+
def on_step_end(self, args, state, control, model=None, **kwargs):
|
| 381 |
+
self.step_count += 1
|
| 382 |
+
if self.step_count % logging_steps == 0:
|
| 383 |
+
self.logs.append(f"Step {self.step_count}/{total_steps}")
|
| 384 |
+
|
| 385 |
+
def on_epoch_end(self, args, state, control, model=None, **kwargs):
|
| 386 |
+
epoch = int(state.epoch)
|
| 387 |
+
self.logs.append(f"✅ Epoch {epoch} completed")
|
| 388 |
+
|
| 389 |
+
def on_evaluate(self, args, state, control, model=None, logs=None, **kwargs):
|
| 390 |
+
if logs:
|
| 391 |
+
acc = logs.get('eval_accuracy', 0)
|
| 392 |
+
loss = logs.get('eval_loss', 0)
|
| 393 |
+
self.logs.append(f"📊 Eval - Accuracy: {acc:.4f}, Loss: {loss:.4f}")
|
| 394 |
+
|
| 395 |
+
progress_callback = ProgressCallback(TRAINING_LOGS)
|
| 396 |
+
trainer.add_callback(progress_callback)
|
| 397 |
+
|
| 398 |
+
# Train the model
|
| 399 |
+
try:
|
| 400 |
+
trainer.train()
|
| 401 |
+
TRAINING_LOGS.append("✅ Training completed successfully!")
|
| 402 |
+
yield "\n".join(TRAINING_LOGS)
|
| 403 |
+
except Exception as e:
|
| 404 |
+
TRAINING_LOGS.append(f"❌ Training failed: {str(e)}")
|
| 405 |
+
yield "\n".join(TRAINING_LOGS)
|
| 406 |
+
return
|
| 407 |
+
|
| 408 |
+
# Save the model
|
| 409 |
+
TRAINING_LOGS.append("💾 Saving model...")
|
| 410 |
+
yield "\n".join(TRAINING_LOGS)
|
| 411 |
+
|
| 412 |
+
trainer.save_model()
|
| 413 |
+
tokenizer.save_pretrained(output_dir)
|
| 414 |
+
|
| 415 |
+
# Update global model and tokenizer
|
| 416 |
+
CURRENT_MODEL = model
|
| 417 |
+
CURRENT_TOKENIZER = tokenizer
|
| 418 |
+
|
| 419 |
+
TRAINING_LOGS.append("✅ Model saved successfully!")
|
| 420 |
+
yield "\n".join(TRAINING_LOGS)
|
| 421 |
+
|
| 422 |
+
# Final evaluation
|
| 423 |
+
TRAINING_LOGS.append("📊 Running final evaluation...")
|
| 424 |
+
yield "\n".join(TRAINING_LOGS)
|
| 425 |
+
|
| 426 |
+
try:
|
| 427 |
+
eval_results = trainer.evaluate()
|
| 428 |
+
TRAINING_LOGS.append("📊 Final Results:")
|
| 429 |
+
for key, value in eval_results.items():
|
| 430 |
+
if isinstance(value, float):
|
| 431 |
+
TRAINING_LOGS.append(f" {key}: {value:.4f}")
|
| 432 |
+
else:
|
| 433 |
+
TRAINING_LOGS.append(f" {key}: {value}")
|
| 434 |
+
|
| 435 |
+
# Save results
|
| 436 |
+
with open(output_dir / "eval_results.json", "w") as f:
|
| 437 |
+
json.dump(eval_results, f, indent=2)
|
| 438 |
+
|
| 439 |
+
except Exception as e:
|
| 440 |
+
TRAINING_LOGS.append(f"⚠️ Evaluation error: {str(e)}")
|
| 441 |
+
|
| 442 |
+
yield "\n".join(TRAINING_LOGS)
|
| 443 |
+
|
| 444 |
+
# Push to hub if requested
|
| 445 |
+
if push_to_hub and hub_model_id:
|
| 446 |
+
TRAINING_LOGS.append(f"🤗 Pushing to Hugging Face Hub: {hub_model_id}")
|
| 447 |
+
yield "\n".join(TRAINING_LOGS)
|
| 448 |
+
|
| 449 |
+
try:
|
| 450 |
+
trainer.push_to_hub()
|
| 451 |
+
TRAINING_LOGS.append(f"✅ Successfully pushed to {hub_model_id}")
|
| 452 |
+
except Exception as e:
|
| 453 |
+
TRAINING_LOGS.append(f"❌ Push to Hub failed: {str(e)}")
|
| 454 |
+
|
| 455 |
+
yield "\n".join(TRAINING_LOGS)
|
| 456 |
+
|
| 457 |
+
TRAINING_LOGS.append("\n✨ Training completed! Your model is ready to use.")
|
| 458 |
+
yield "\n".join(TRAINING_LOGS)
|
| 459 |
+
|
| 460 |
+
except Exception as e:
|
| 461 |
+
TRAINING_LOGS.append(f"❌ Error during training: {str(e)}")
|
| 462 |
+
yield "\n".join(TRAINING_LOGS)
|
| 463 |
+
|
| 464 |
def predict_text(text, model_path):
|
| 465 |
"""Make a prediction on a single text input"""
|
| 466 |
global CURRENT_MODEL, CURRENT_TOKENIZER
|
|
|
|
| 472 |
return load_result
|
| 473 |
|
| 474 |
try:
|
| 475 |
+
if not text.strip():
|
| 476 |
+
return "Please enter some text to classify."
|
| 477 |
+
|
| 478 |
+
# Check if text was truncated
|
| 479 |
+
original_tokens = CURRENT_TOKENIZER(text, truncation=False)
|
| 480 |
+
was_truncated = len(original_tokens['input_ids']) > 512
|
| 481 |
+
|
| 482 |
# Tokenize input
|
| 483 |
inputs = CURRENT_TOKENIZER(text, return_tensors="pt", truncation=True, max_length=512)
|
| 484 |
|
| 485 |
# Make prediction
|
| 486 |
with torch.no_grad():
|
| 487 |
outputs = CURRENT_MODEL(**inputs)
|
| 488 |
+
predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
|
| 489 |
+
predicted_class_id = predictions.argmax().item()
|
| 490 |
+
confidence = predictions.max().item()
|
| 491 |
|
| 492 |
+
# Get predicted category
|
| 493 |
+
predicted_category = idx_to_category[predicted_class_id]
|
| 494 |
|
| 495 |
+
# Format result
|
|
|
|
|
|
|
| 496 |
truncation_warning = "\n\n⚠️ Note: This complaint was truncated to fit BERT's 512 token limit." if was_truncated else ""
|
| 497 |
|
| 498 |
+
result = []
|
| 499 |
+
result.append(f"**Complaint:** {text}")
|
| 500 |
+
result.append(f"\n**Predicted Category:** {predicted_category}")
|
| 501 |
+
result.append(f"**Confidence:** {confidence:.4f}")
|
| 502 |
+
result.append("\n**All Class Probabilities:**")
|
| 503 |
+
|
| 504 |
+
for i, category in enumerate(CATEGORIES):
|
| 505 |
+
prob = predictions[0][i].item()
|
| 506 |
+
result.append(f"- {category}: {prob:.4f}")
|
| 507 |
+
|
| 508 |
+
result.append(truncation_warning)
|
| 509 |
+
|
| 510 |
+
return "\n".join(result)
|
| 511 |
+
|
| 512 |
except Exception as e:
|
| 513 |
+
return f"❌ Prediction error: {str(e)}"
|
| 514 |
|
| 515 |
def predict_csv(csv_file, model_path):
|
| 516 |
"""Make predictions on a CSV file with complaints"""
|
|
|
|
| 533 |
return "❌ CSV file must have a 'complaint' column"
|
| 534 |
|
| 535 |
results = []
|
| 536 |
+
predictions_list = []
|
| 537 |
truncated_count = 0
|
| 538 |
|
| 539 |
for i, row in enumerate(df.iterrows()):
|
|
|
|
| 549 |
inputs = CURRENT_TOKENIZER(complaint, return_tensors="pt", truncation=True, max_length=512)
|
| 550 |
with torch.no_grad():
|
| 551 |
outputs = CURRENT_MODEL(**inputs)
|
| 552 |
+
predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
|
| 553 |
+
predicted_idx = predictions.argmax().item()
|
| 554 |
+
confidence = predictions.max().item()
|
| 555 |
|
| 556 |
predicted_category = idx_to_category[predicted_idx]
|
| 557 |
+
predictions_list.append({
|
| 558 |
+
'complaint': complaint,
|
| 559 |
+
'predicted_category': predicted_category,
|
| 560 |
+
'confidence': confidence,
|
| 561 |
+
'truncated': was_truncated
|
| 562 |
+
})
|
| 563 |
|
| 564 |
truncation_mark = " ⚠️" if was_truncated else ""
|
| 565 |
preview = complaint if len(complaint) <= 50 else complaint[:47] + "..."
|
| 566 |
+
results.append(f"Complaint {i+1}{truncation_mark}: {preview}")
|
| 567 |
+
results.append(f"Predicted: {predicted_category} (confidence: {confidence:.3f})\n")
|
| 568 |
|
| 569 |
if i >= 19:
|
| 570 |
results.append(f"... and {len(df) - 20} more (showing first 20 out of {len(df)} complaints)")
|
|
|
|
| 573 |
if truncated_count > 0:
|
| 574 |
results.append(f"\n⚠️ {truncated_count} complaints were truncated to fit BERT's 512 token limit.")
|
| 575 |
|
| 576 |
+
# Save full results to a CSV file
|
| 577 |
+
results_df = pd.DataFrame(predictions_list)
|
| 578 |
+
results_file = "prediction_results.csv"
|
| 579 |
+
results_df.to_csv(results_file, index=False)
|
| 580 |
+
results.append(f"\n💾 Full results saved to {results_file}")
|
| 581 |
+
|
| 582 |
return "\n".join(results)
|
| 583 |
+
|
| 584 |
except Exception as e:
|
| 585 |
return f"❌ CSV processing failed: {str(e)}"
|
| 586 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 587 |
def push_to_hub_after_training(model_path, username, model_name, token):
|
| 588 |
"""Push a trained model to Hugging Face Hub"""
|
| 589 |
try:
|
|
|
|
| 616 |
except Exception as e:
|
| 617 |
return f"❌ Error: {str(e)}"
|
| 618 |
|
| 619 |
+
def count_tokens(text):
    """Return a human-readable token count for *text* against BERT's 512-token limit.

    Uses the globally loaded tokenizer; falls back to a prompt message when
    no text is given or the tokenizer has not been initialized yet.
    """
    if not text or CURRENT_TOKENIZER is None:
        return "Enter text to see token count"

    encoded = CURRENT_TOKENIZER(text, truncation=False)
    n_tokens = len(encoded['input_ids'])

    # Warn when the input exceeds what the model will actually see.
    if n_tokens > 512:
        return f"⚠️ **Token count: {n_tokens}/512** - Text will be truncated for BERT"
    return f"Token count: {n_tokens}/512"
|
| 629 |
+
|
| 630 |
+
def get_available_datasets():
    """Get list of available CSV files in the current directory.

    Returns:
        list[str]: One entry per CSV file, formatted as
        ``"<name> (<rows> rows)"`` when the file parses, or
        ``"<name> (Error reading)"`` when it cannot be read.
        If no CSV files exist, a single placeholder message is returned
        instead of an empty list.
    """
    available_files = []
    # Sorted for deterministic, user-friendly ordering (os.listdir order is arbitrary).
    for fname in sorted(os.listdir(".")):
        if fname.endswith(".csv"):
            try:
                df = pd.read_csv(fname)
                available_files.append(f"{fname} ({len(df)} rows)")
            except Exception:
                # Narrowed from a bare `except:`, which would also swallow
                # KeyboardInterrupt/SystemExit.
                available_files.append(f"{fname} (Error reading)")

    if not available_files:
        available_files = ["No CSV files found in current directory"]

    return available_files
|
| 645 |
+
def display_available_datasets():
    """Render the available CSV datasets as a Markdown bullet list."""
    datasets = get_available_datasets()
    if not datasets:
        return "No CSV files found in the current directory."
    bullet_lines = "\n".join(f"- {entry}" for entry in datasets)
    return "**Available CSV files:**\n\n" + bullet_lines
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 651 |
|
| 652 |
+
# Initialize the display
|
| 653 |
+
refresh_datasets_btn.click(
|
| 654 |
+
display_available_datasets,
|
| 655 |
+
outputs=available_datasets
|
| 656 |
+
)
|
| 657 |
|
| 658 |
+
# Show datasets on load
|
| 659 |
+
app.load(display_available_datasets, outputs=available_datasets)
|
| 660 |
|
| 661 |
+
gr.Markdown("### Dataset Format Requirements")
|
| 662 |
+
gr.Markdown("""
|
| 663 |
+
**For training, your CSV file should have:**
|
| 664 |
+
- A text column containing the complaint text (default name: 'complaint')
|
| 665 |
+
- A label column containing categories (default name: 'category')
|
|
|
|
| 666 |
|
| 667 |
+
**Supported label formats:**
|
| 668 |
+
- Text labels: 'Online-Safety', 'BroadBand', 'TV-Radio'
|
| 669 |
+
- Numeric labels: 0, 1, 2 (corresponding to the categories above)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 670 |
|
| 671 |
+
**Example CSV structure:**
|
| 672 |
+
```
|
| 673 |
+
complaint,category
|
| 674 |
+
"My internet is slow",BroadBand
|
| 675 |
+
"Blocked website access",Online-Safety
|
| 676 |
+
"Poor TV signal",TV-Radio
|
| 677 |
+
```
|
| 678 |
+
""")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 679 |
|
| 680 |
+
gr.Markdown("### Model Categories")
|
| 681 |
+
categories_info = f"""
|
| 682 |
+
**The model classifies complaints into these categories:**
|
|
|
|
|
|
|
| 683 |
|
| 684 |
+
| Index | Category | Description |
|
| 685 |
+
|-------|----------|-------------|
|
| 686 |
+
| 0 | Online-Safety | Internet safety, content filtering, cybersecurity issues |
|
| 687 |
+
| 1 | BroadBand | Internet connectivity, speed, network problems |
|
| 688 |
+
| 2 | TV-Radio | Television and radio broadcasting, signal quality issues |
|
| 689 |
+
"""
|
| 690 |
+
gr.Markdown(categories_info)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 691 |
|
| 692 |
# Launch the app
|
| 693 |
if __name__ == "__main__":
    # Eagerly load the tokenizer so token counting / single-text prediction
    # work before any training run has populated it.
    if CURRENT_TOKENIZER is None:
        try:
            CURRENT_TOKENIZER = AutoTokenizer.from_pretrained("bert-base-uncased")
        except Exception as e:
            print(f"⚠️ Warning: Could not initialize tokenizer: {e}")
        else:
            print("✅ Tokenizer initialized successfully")

    print("🚀 Launching BERT Complaint Classifier...")
    print("📍 Available at: http://localhost:7860")

    # Bind on all interfaces so the Space/container can route traffic in.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
    }
    app.launch(**launch_options)
|
| 711 |
+
|