Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -12,7 +12,7 @@ from sklearn.model_selection import train_test_split
|
|
| 12 |
|
| 13 |
from huggingface_hub import login, HfApi
|
| 14 |
from transformers import (
|
| 15 |
-
AutoTokenizer,
|
| 16 |
BertForSequenceClassification,
|
| 17 |
TrainingArguments,
|
| 18 |
Trainer,
|
|
@@ -35,16 +35,16 @@ TRAINING_LOGS = []
|
|
| 35 |
CURRENT_MODEL = None
|
| 36 |
CURRENT_TOKENIZER = None
|
| 37 |
|
| 38 |
-
# --- Application Logic Functions (No change needed here, they are correctly indented) ---
|
| 39 |
-
|
| 40 |
def load_and_prepare_local_dataset(file_path, text_column, label_column, test_size=0.2):
|
| 41 |
"""Load and prepare local CSV dataset for training"""
|
| 42 |
try:
|
| 43 |
if not os.path.exists(file_path):
|
| 44 |
raise FileNotFoundError(f"Dataset file not found: {file_path}")
|
| 45 |
|
|
|
|
| 46 |
df = pd.read_csv(file_path)
|
| 47 |
|
|
|
|
| 48 |
if text_column not in df.columns:
|
| 49 |
available_cols = list(df.columns)
|
| 50 |
raise ValueError(f"Text column '{text_column}' not found. Available columns: {available_cols}")
|
|
@@ -53,19 +53,24 @@ def load_and_prepare_local_dataset(file_path, text_column, label_column, test_si
|
|
| 53 |
available_cols = list(df.columns)
|
| 54 |
raise ValueError(f"Label column '{label_column}' not found. Available columns: {available_cols}")
|
| 55 |
|
|
|
|
| 56 |
df = df.dropna(subset=[text_column, label_column])
|
| 57 |
df[text_column] = df[text_column].astype(str)
|
| 58 |
|
|
|
|
| 59 |
if df[label_column].dtype == 'object':
|
|
|
|
| 60 |
unique_labels = df[label_column].unique()
|
| 61 |
if len(unique_labels) > len(CATEGORIES):
|
| 62 |
raise ValueError(f"Too many unique labels ({len(unique_labels)}). Expected max {len(CATEGORIES)}")
|
| 63 |
|
|
|
|
| 64 |
label_mapping = {}
|
| 65 |
for label in unique_labels:
|
| 66 |
if label in category_to_idx:
|
| 67 |
label_mapping[label] = category_to_idx[label]
|
| 68 |
else:
|
|
|
|
| 69 |
available_indices = set(range(len(CATEGORIES))) - set(label_mapping.values())
|
| 70 |
if available_indices:
|
| 71 |
label_mapping[label] = min(available_indices)
|
|
@@ -74,18 +79,22 @@ def load_and_prepare_local_dataset(file_path, text_column, label_column, test_si
|
|
| 74 |
|
| 75 |
df['label_idx'] = df[label_column].map(label_mapping)
|
| 76 |
else:
|
|
|
|
| 77 |
df['label_idx'] = df[label_column].astype(int)
|
| 78 |
|
|
|
|
| 79 |
if df['label_idx'].min() < 0 or df['label_idx'].max() >= len(CATEGORIES):
|
| 80 |
raise ValueError(f"Label indices must be between 0 and {len(CATEGORIES)-1}")
|
| 81 |
|
|
|
|
| 82 |
train_df, val_df = train_test_split(
|
| 83 |
-
df,
|
| 84 |
-
test_size=test_size,
|
| 85 |
-
random_state=42,
|
| 86 |
stratify=df['label_idx']
|
| 87 |
)
|
| 88 |
|
|
|
|
| 89 |
train_dataset = Dataset.from_pandas(train_df[[text_column, 'label_idx']])
|
| 90 |
val_dataset = Dataset.from_pandas(val_df[[text_column, 'label_idx']])
|
| 91 |
|
|
@@ -105,6 +114,7 @@ def preview_dataset(uploaded_file, text_column, label_column):
|
|
| 105 |
if uploaded_file is None:
|
| 106 |
return "Please upload a dataset file first."
|
| 107 |
|
|
|
|
| 108 |
file_path = uploaded_file.name if hasattr(uploaded_file, 'name') else uploaded_file
|
| 109 |
|
| 110 |
df = pd.read_csv(file_path)
|
|
@@ -152,9 +162,11 @@ def validate_hub_model_id(username, model_name):
|
|
| 152 |
if not username or not model_name:
|
| 153 |
return None, "Please provide both username and model name"
|
| 154 |
|
|
|
|
| 155 |
model_name = model_name.strip().lower().replace(" ", "-")
|
| 156 |
model_name = ''.join(c for c in model_name if c.isalnum() or c in ['-', '_'])
|
| 157 |
|
|
|
|
| 158 |
hub_model_id = f"{username}/{model_name}"
|
| 159 |
|
| 160 |
return hub_model_id, None
|
|
@@ -164,6 +176,7 @@ def load_model(model_path):
|
|
| 164 |
global CURRENT_MODEL, CURRENT_TOKENIZER
|
| 165 |
|
| 166 |
try:
|
|
|
|
| 167 |
if os.path.exists(model_path):
|
| 168 |
CURRENT_TOKENIZER = AutoTokenizer.from_pretrained(model_path)
|
| 169 |
CURRENT_MODEL = BertForSequenceClassification.from_pretrained(
|
|
@@ -172,6 +185,7 @@ def load_model(model_path):
|
|
| 172 |
)
|
| 173 |
return f"โ
Model loaded from local path: {model_path}"
|
| 174 |
|
|
|
|
| 175 |
try:
|
| 176 |
CURRENT_TOKENIZER = AutoTokenizer.from_pretrained(model_path)
|
| 177 |
CURRENT_MODEL = BertForSequenceClassification.from_pretrained(
|
|
@@ -180,6 +194,7 @@ def load_model(model_path):
|
|
| 180 |
)
|
| 181 |
return f"โ
Model loaded from Hugging Face Hub: {model_path}"
|
| 182 |
except Exception as hub_error:
|
|
|
|
| 183 |
CURRENT_TOKENIZER = AutoTokenizer.from_pretrained("bert-base-uncased")
|
| 184 |
CURRENT_MODEL = BertForSequenceClassification.from_pretrained(
|
| 185 |
"bert-base-uncased",
|
|
@@ -215,7 +230,7 @@ def compute_metrics(eval_pred):
|
|
| 215 |
'recall_macro': report['macro avg']['recall']
|
| 216 |
}
|
| 217 |
|
| 218 |
-
def train_model_inline(uploaded_file, text_column, label_column, num_epochs, batch_size,
|
| 219 |
learning_rate, hf_token, push_to_hub, username, model_name):
|
| 220 |
"""Train the model using inline training (no subprocess)"""
|
| 221 |
global TRAINING_LOGS, MODEL_PATH, CURRENT_MODEL, CURRENT_TOKENIZER
|
|
@@ -227,6 +242,7 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
|
|
| 227 |
TRAINING_LOGS.append(login_result)
|
| 228 |
yield "\n".join(TRAINING_LOGS)
|
| 229 |
|
|
|
|
| 230 |
if push_to_hub:
|
| 231 |
hub_model_id, error = validate_hub_model_id(username, model_name)
|
| 232 |
if error:
|
|
@@ -236,6 +252,7 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
|
|
| 236 |
else:
|
| 237 |
hub_model_id = None
|
| 238 |
|
|
|
|
| 239 |
if uploaded_file is None:
|
| 240 |
TRAINING_LOGS.append("โ Please upload a dataset file")
|
| 241 |
yield "\n".join(TRAINING_LOGS)
|
|
@@ -244,6 +261,7 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
|
|
| 244 |
dataset_file = uploaded_file.name if hasattr(uploaded_file, 'name') else uploaded_file
|
| 245 |
|
| 246 |
try:
|
|
|
|
| 247 |
TRAINING_LOGS.append(f"๐ Loading dataset from uploaded file...")
|
| 248 |
yield "\n".join(TRAINING_LOGS)
|
| 249 |
|
|
@@ -256,6 +274,7 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
|
|
| 256 |
TRAINING_LOGS.append(f"- Validation samples: {len(dataset_dict['validation'])}")
|
| 257 |
yield "\n".join(TRAINING_LOGS)
|
| 258 |
|
|
|
|
| 259 |
TRAINING_LOGS.append("๐ค Loading BERT model and tokenizer...")
|
| 260 |
yield "\n".join(TRAINING_LOGS)
|
| 261 |
|
|
@@ -268,12 +287,14 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
|
|
| 268 |
TRAINING_LOGS.append("โ
Model and tokenizer loaded")
|
| 269 |
yield "\n".join(TRAINING_LOGS)
|
| 270 |
|
|
|
|
| 271 |
TRAINING_LOGS.append("๐ค Tokenizing datasets...")
|
| 272 |
yield "\n".join(TRAINING_LOGS)
|
| 273 |
|
| 274 |
def tokenize_batch(examples):
|
| 275 |
return tokenize_function(examples, tokenizer, final_text_col, 512)
|
| 276 |
|
|
|
|
| 277 |
columns_to_remove = [col for col in dataset_dict['train'].column_names if col != final_label_col]
|
| 278 |
|
| 279 |
tokenized_datasets = dataset_dict.map(
|
|
@@ -282,14 +303,17 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
|
|
| 282 |
remove_columns=columns_to_remove
|
| 283 |
)
|
| 284 |
|
|
|
|
| 285 |
tokenized_datasets = tokenized_datasets.rename_column(final_label_col, 'labels')
|
| 286 |
|
| 287 |
TRAINING_LOGS.append("โ
Tokenization completed")
|
| 288 |
yield "\n".join(TRAINING_LOGS)
|
| 289 |
|
|
|
|
| 290 |
output_dir = Path(MODEL_PATH)
|
| 291 |
output_dir.mkdir(parents=True, exist_ok=True)
|
| 292 |
|
|
|
|
| 293 |
total_steps = len(tokenized_datasets['train']) // batch_size * num_epochs
|
| 294 |
eval_steps = max(10, min(100, total_steps // 4))
|
| 295 |
save_steps = max(20, min(500, total_steps // 2))
|
|
@@ -302,6 +326,7 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
|
|
| 302 |
TRAINING_LOGS.append(f"- Warmup steps: {warmup_steps}")
|
| 303 |
yield "\n".join(TRAINING_LOGS)
|
| 304 |
|
|
|
|
| 305 |
training_args = TrainingArguments(
|
| 306 |
output_dir=str(output_dir),
|
| 307 |
num_train_epochs=num_epochs,
|
|
@@ -328,8 +353,10 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
|
|
| 328 |
remove_unused_columns=False,
|
| 329 |
)
|
| 330 |
|
|
|
|
| 331 |
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
|
| 332 |
|
|
|
|
| 333 |
trainer = Trainer(
|
| 334 |
model=model,
|
| 335 |
args=training_args,
|
|
@@ -344,6 +371,7 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
|
|
| 344 |
TRAINING_LOGS.append("๐ Starting training...")
|
| 345 |
yield "\n".join(TRAINING_LOGS)
|
| 346 |
|
|
|
|
| 347 |
class ProgressCallback:
|
| 348 |
def __init__(self, logs_list):
|
| 349 |
self.logs = logs_list
|
|
@@ -367,6 +395,7 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
|
|
| 367 |
progress_callback = ProgressCallback(TRAINING_LOGS)
|
| 368 |
trainer.add_callback(progress_callback)
|
| 369 |
|
|
|
|
| 370 |
try:
|
| 371 |
trainer.train()
|
| 372 |
TRAINING_LOGS.append("โ
Training completed successfully!")
|
|
@@ -376,18 +405,21 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
|
|
| 376 |
yield "\n".join(TRAINING_LOGS)
|
| 377 |
return
|
| 378 |
|
|
|
|
| 379 |
TRAINING_LOGS.append("๐พ Saving model...")
|
| 380 |
yield "\n".join(TRAINING_LOGS)
|
| 381 |
|
| 382 |
trainer.save_model()
|
| 383 |
tokenizer.save_pretrained(output_dir)
|
| 384 |
|
|
|
|
| 385 |
CURRENT_MODEL = model
|
| 386 |
CURRENT_TOKENIZER = tokenizer
|
| 387 |
|
| 388 |
TRAINING_LOGS.append("โ
Model saved successfully!")
|
| 389 |
yield "\n".join(TRAINING_LOGS)
|
| 390 |
|
|
|
|
| 391 |
TRAINING_LOGS.append("๐ Running final evaluation...")
|
| 392 |
yield "\n".join(TRAINING_LOGS)
|
| 393 |
|
|
@@ -400,6 +432,7 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
|
|
| 400 |
else:
|
| 401 |
TRAINING_LOGS.append(f" {key}: {value}")
|
| 402 |
|
|
|
|
| 403 |
with open(output_dir / "eval_results.json", "w") as f:
|
| 404 |
json.dump(eval_results, f, indent=2)
|
| 405 |
|
|
@@ -408,6 +441,7 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
|
|
| 408 |
|
| 409 |
yield "\n".join(TRAINING_LOGS)
|
| 410 |
|
|
|
|
| 411 |
if push_to_hub and hub_model_id:
|
| 412 |
TRAINING_LOGS.append(f"๐ค Pushing to Hugging Face Hub: {hub_model_id}")
|
| 413 |
yield "\n".join(TRAINING_LOGS)
|
|
@@ -431,6 +465,7 @@ def predict_text(text, model_path):
|
|
| 431 |
"""Make a prediction on a single text input"""
|
| 432 |
global CURRENT_MODEL, CURRENT_TOKENIZER
|
| 433 |
|
|
|
|
| 434 |
if CURRENT_MODEL is None or model_path != MODEL_PATH:
|
| 435 |
load_result = load_model(model_path)
|
| 436 |
if load_result.startswith("โ"):
|
|
@@ -440,19 +475,24 @@ def predict_text(text, model_path):
|
|
| 440 |
if not text.strip():
|
| 441 |
return "Please enter some text to classify."
|
| 442 |
|
|
|
|
| 443 |
original_tokens = CURRENT_TOKENIZER(text, truncation=False)
|
| 444 |
was_truncated = len(original_tokens['input_ids']) > 512
|
| 445 |
|
|
|
|
| 446 |
inputs = CURRENT_TOKENIZER(text, return_tensors="pt", truncation=True, max_length=512)
|
| 447 |
|
|
|
|
| 448 |
with torch.no_grad():
|
| 449 |
outputs = CURRENT_MODEL(**inputs)
|
| 450 |
predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
|
| 451 |
predicted_class_id = predictions.argmax().item()
|
| 452 |
confidence = predictions.max().item()
|
| 453 |
|
|
|
|
| 454 |
predicted_category = idx_to_category[predicted_class_id]
|
| 455 |
|
|
|
|
| 456 |
truncation_warning = "\n\nโ ๏ธ Note: This complaint was truncated to fit BERT's 512 token limit." if was_truncated else ""
|
| 457 |
|
| 458 |
result = []
|
|
@@ -476,12 +516,14 @@ def predict_csv(csv_file, model_path):
|
|
| 476 |
"""Make predictions on a CSV file with complaints"""
|
| 477 |
global CURRENT_MODEL, CURRENT_TOKENIZER
|
| 478 |
|
|
|
|
| 479 |
if CURRENT_MODEL is None or model_path != MODEL_PATH:
|
| 480 |
load_result = load_model(model_path)
|
| 481 |
if load_result.startswith("โ"):
|
| 482 |
return load_result, None
|
| 483 |
|
| 484 |
try:
|
|
|
|
| 485 |
if hasattr(csv_file, 'name'):
|
| 486 |
df = pd.read_csv(csv_file.name)
|
| 487 |
else:
|
|
@@ -497,11 +539,13 @@ def predict_csv(csv_file, model_path):
|
|
| 497 |
for i, row in enumerate(df.iterrows()):
|
| 498 |
complaint = str(row[1]['complaint'])
|
| 499 |
|
|
|
|
| 500 |
original_tokens = CURRENT_TOKENIZER(complaint, truncation=False)
|
| 501 |
was_truncated = len(original_tokens['input_ids']) > 512
|
| 502 |
if was_truncated:
|
| 503 |
truncated_count += 1
|
| 504 |
|
|
|
|
| 505 |
inputs = CURRENT_TOKENIZER(complaint, return_tensors="pt", truncation=True, max_length=512)
|
| 506 |
with torch.no_grad():
|
| 507 |
outputs = CURRENT_MODEL(**inputs)
|
|
@@ -529,6 +573,7 @@ def predict_csv(csv_file, model_path):
|
|
| 529 |
if truncated_count > 0:
|
| 530 |
results.append(f"\nโ ๏ธ {truncated_count} complaints were truncated to fit BERT's 512 token limit.")
|
| 531 |
|
|
|
|
| 532 |
results_df = pd.DataFrame(predictions_list)
|
| 533 |
results_file = "prediction_results.csv"
|
| 534 |
results_df.to_csv(results_file, index=False)
|
|
@@ -549,6 +594,7 @@ def push_to_hub_after_training(model_path, username, model_name, token):
|
|
| 549 |
if error:
|
| 550 |
return f"โ {error}"
|
| 551 |
|
|
|
|
| 552 |
login(token)
|
| 553 |
if not os.path.exists(model_path):
|
| 554 |
return "โ No trained model found. Please train a model first."
|
|
@@ -559,6 +605,7 @@ def push_to_hub_after_training(model_path, username, model_name, token):
|
|
| 559 |
except Exception as e:
|
| 560 |
return f"โ Failed to load model: {str(e)}"
|
| 561 |
|
|
|
|
| 562 |
try:
|
| 563 |
model.push_to_hub(hub_model_id)
|
| 564 |
tokenizer.push_to_hub(hub_model_id)
|
|
@@ -603,8 +650,6 @@ def display_available_datasets():
|
|
| 603 |
else:
|
| 604 |
return "No CSV files found in the current directory."
|
| 605 |
|
| 606 |
-
# --- Gradio UI Definition (Correctly structured) ---
|
| 607 |
-
|
| 608 |
# Initialize tokenizer on startup
|
| 609 |
if CURRENT_TOKENIZER is None:
|
| 610 |
try:
|
|
@@ -616,7 +661,7 @@ if CURRENT_TOKENIZER is None:
|
|
| 616 |
print("๐ Launching BERT Complaint Classifier...")
|
| 617 |
print("๐ Available at: http://localhost:7860")
|
| 618 |
|
| 619 |
-
# The entire Gradio UI definition must be within
|
| 620 |
with gr.Blocks(title="BERT Complaint Classifier", theme=gr.themes.Soft()) as app:
|
| 621 |
gr.Markdown("# BERT Complaint Classifier ๐ฃ๏ธ๐ค")
|
| 622 |
gr.Markdown("Fine-tune a BERT model or use an existing one to classify customer complaints.")
|
|
@@ -666,98 +711,86 @@ with gr.Blocks(title="BERT Complaint Classifier", theme=gr.themes.Soft()) as app
|
|
| 666 |
with gr.Column(variant="panel"):
|
| 667 |
gr.Markdown("### Classify a Single Complaint")
|
| 668 |
|
| 669 |
-
model_path_input = gr.Textbox(
|
| 670 |
-
|
| 671 |
-
|
| 672 |
-
placeholder="e.g., local-model or your_username/your_model"
|
| 673 |
-
)
|
| 674 |
|
| 675 |
-
|
| 676 |
-
text_input = gr.Textbox(label="Complaint Text", lines=3)
|
| 677 |
-
token_count_output = gr.Markdown("Token count: 0/512")
|
| 678 |
|
| 679 |
-
|
| 680 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 681 |
|
| 682 |
-
|
| 683 |
-
|
| 684 |
-
|
|
|
|
|
|
|
| 685 |
with gr.Column(variant="panel"):
|
| 686 |
-
gr.Markdown("### Classify Complaints from
|
| 687 |
-
|
| 688 |
-
csv_model_path = gr.Textbox(
|
| 689 |
-
label="Model Path or Hub ID",
|
| 690 |
-
value="local-model",
|
| 691 |
-
placeholder="e.g., local-model or your_username/your_model"
|
| 692 |
-
)
|
| 693 |
-
csv_predict_btn = gr.Button("Run Predictions on CSV", variant="primary")
|
| 694 |
-
csv_prediction_output = gr.Markdown("Predictions will appear here...")
|
| 695 |
-
download_link = gr.File(label="Download Full Predictions", interactive=False)
|
| 696 |
|
| 697 |
-
|
| 698 |
-
|
| 699 |
-
|
| 700 |
-
|
| 701 |
-
|
| 702 |
-
|
| 703 |
-
|
| 704 |
-
|
| 705 |
-
|
| 706 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 707 |
|
| 708 |
-
|
| 709 |
-
|
| 710 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 711 |
|
| 712 |
-
|
| 713 |
-
|
| 714 |
-
|
| 715 |
-
|
| 716 |
-
|
| 717 |
-
|
| 718 |
-
|
| 719 |
-
|
|
|
|
| 720 |
|
| 721 |
-
|
| 722 |
-
|
| 723 |
-
|
| 724 |
|
| 725 |
-
|
| 726 |
-
```
|
| 727 |
-
complaint,category
|
| 728 |
-
"My internet is slow",BroadBand
|
| 729 |
-
"Blocked website access",Online-Safety
|
| 730 |
-
"Poor TV signal",TV-Radio
|
| 731 |
-
```
|
| 732 |
-
""")
|
| 733 |
-
gr.Markdown("### Model Categories")
|
| 734 |
-
categories_info = f"""
|
| 735 |
-
**The model classifies complaints into these categories:**
|
| 736 |
|
| 737 |
-
|
| 738 |
-
|
| 739 |
-
|
| 740 |
-
|
| 741 |
-
|
| 742 |
-
|
| 743 |
-
|
| 744 |
-
|
| 745 |
-
with gr.Accordion("Push Local Model to Hub"):
|
| 746 |
-
gr.Markdown("Use this to manually push a locally trained model (`./local-model`) to the Hub.")
|
| 747 |
-
with gr.Row():
|
| 748 |
-
hub_username_input_push = gr.Textbox(label="Hugging Face Username")
|
| 749 |
-
hub_model_name_input_push = gr.Textbox(label="Model Name")
|
| 750 |
-
hub_token_input_push = gr.Textbox(label="Hugging Face Token", type="password")
|
| 751 |
|
| 752 |
-
|
| 753 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 754 |
|
| 755 |
-
|
| 756 |
-
push_to_hub_after_training,
|
| 757 |
-
inputs=[gr.Textbox(value=MODEL_PATH, visible=False), hub_username_input_push, hub_model_name_input_push, hub_token_input_push],
|
| 758 |
-
outputs=push_output
|
| 759 |
-
)
|
| 760 |
-
|
| 761 |
preview_btn.click(
|
| 762 |
preview_dataset,
|
| 763 |
inputs=[uploaded_file, text_column_input, label_column_input],
|
|
@@ -781,13 +814,51 @@ with gr.Blocks(title="BERT Complaint Classifier", theme=gr.themes.Soft()) as app
|
|
| 781 |
outputs=training_log_output,
|
| 782 |
)
|
| 783 |
|
| 784 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 785 |
display_available_datasets,
|
| 786 |
-
outputs=
|
| 787 |
)
|
| 788 |
|
| 789 |
-
app
|
|
|
|
| 790 |
|
|
|
|
| 791 |
if __name__ == "__main__":
|
| 792 |
app.launch(
|
| 793 |
server_name="0.0.0.0",
|
|
|
|
| 12 |
|
| 13 |
from huggingface_hub import login, HfApi
|
| 14 |
from transformers import (
|
| 15 |
+
AutoTokenizer,
|
| 16 |
BertForSequenceClassification,
|
| 17 |
TrainingArguments,
|
| 18 |
Trainer,
|
|
|
|
| 35 |
CURRENT_MODEL = None
|
| 36 |
CURRENT_TOKENIZER = None
|
| 37 |
|
|
|
|
|
|
|
| 38 |
def load_and_prepare_local_dataset(file_path, text_column, label_column, test_size=0.2):
|
| 39 |
"""Load and prepare local CSV dataset for training"""
|
| 40 |
try:
|
| 41 |
if not os.path.exists(file_path):
|
| 42 |
raise FileNotFoundError(f"Dataset file not found: {file_path}")
|
| 43 |
|
| 44 |
+
# Load the CSV file
|
| 45 |
df = pd.read_csv(file_path)
|
| 46 |
|
| 47 |
+
# Verify required columns exist
|
| 48 |
if text_column not in df.columns:
|
| 49 |
available_cols = list(df.columns)
|
| 50 |
raise ValueError(f"Text column '{text_column}' not found. Available columns: {available_cols}")
|
|
|
|
| 53 |
available_cols = list(df.columns)
|
| 54 |
raise ValueError(f"Label column '{label_column}' not found. Available columns: {available_cols}")
|
| 55 |
|
| 56 |
+
# Clean the data
|
| 57 |
df = df.dropna(subset=[text_column, label_column])
|
| 58 |
df[text_column] = df[text_column].astype(str)
|
| 59 |
|
| 60 |
+
# Handle different label formats
|
| 61 |
if df[label_column].dtype == 'object':
|
| 62 |
+
# If labels are text, convert to indices
|
| 63 |
unique_labels = df[label_column].unique()
|
| 64 |
if len(unique_labels) > len(CATEGORIES):
|
| 65 |
raise ValueError(f"Too many unique labels ({len(unique_labels)}). Expected max {len(CATEGORIES)}")
|
| 66 |
|
| 67 |
+
# Try to map text labels to our categories
|
| 68 |
label_mapping = {}
|
| 69 |
for label in unique_labels:
|
| 70 |
if label in category_to_idx:
|
| 71 |
label_mapping[label] = category_to_idx[label]
|
| 72 |
else:
|
| 73 |
+
# Auto-assign if not found
|
| 74 |
available_indices = set(range(len(CATEGORIES))) - set(label_mapping.values())
|
| 75 |
if available_indices:
|
| 76 |
label_mapping[label] = min(available_indices)
|
|
|
|
| 79 |
|
| 80 |
df['label_idx'] = df[label_column].map(label_mapping)
|
| 81 |
else:
|
| 82 |
+
# If labels are already numeric
|
| 83 |
df['label_idx'] = df[label_column].astype(int)
|
| 84 |
|
| 85 |
+
# Verify label indices are valid
|
| 86 |
if df['label_idx'].min() < 0 or df['label_idx'].max() >= len(CATEGORIES):
|
| 87 |
raise ValueError(f"Label indices must be between 0 and {len(CATEGORIES)-1}")
|
| 88 |
|
| 89 |
+
# Create train/validation split
|
| 90 |
train_df, val_df = train_test_split(
|
| 91 |
+
df,
|
| 92 |
+
test_size=test_size,
|
| 93 |
+
random_state=42,
|
| 94 |
stratify=df['label_idx']
|
| 95 |
)
|
| 96 |
|
| 97 |
+
# Convert to Hugging Face datasets
|
| 98 |
train_dataset = Dataset.from_pandas(train_df[[text_column, 'label_idx']])
|
| 99 |
val_dataset = Dataset.from_pandas(val_df[[text_column, 'label_idx']])
|
| 100 |
|
|
|
|
| 114 |
if uploaded_file is None:
|
| 115 |
return "Please upload a dataset file first."
|
| 116 |
|
| 117 |
+
# Get the file path from the uploaded file
|
| 118 |
file_path = uploaded_file.name if hasattr(uploaded_file, 'name') else uploaded_file
|
| 119 |
|
| 120 |
df = pd.read_csv(file_path)
|
|
|
|
| 162 |
if not username or not model_name:
|
| 163 |
return None, "Please provide both username and model name"
|
| 164 |
|
| 165 |
+
# Clean up the model name
|
| 166 |
model_name = model_name.strip().lower().replace(" ", "-")
|
| 167 |
model_name = ''.join(c for c in model_name if c.isalnum() or c in ['-', '_'])
|
| 168 |
|
| 169 |
+
# Construct the full model ID
|
| 170 |
hub_model_id = f"{username}/{model_name}"
|
| 171 |
|
| 172 |
return hub_model_id, None
|
|
|
|
| 176 |
global CURRENT_MODEL, CURRENT_TOKENIZER
|
| 177 |
|
| 178 |
try:
|
| 179 |
+
# Try loading from local path first
|
| 180 |
if os.path.exists(model_path):
|
| 181 |
CURRENT_TOKENIZER = AutoTokenizer.from_pretrained(model_path)
|
| 182 |
CURRENT_MODEL = BertForSequenceClassification.from_pretrained(
|
|
|
|
| 185 |
)
|
| 186 |
return f"โ
Model loaded from local path: {model_path}"
|
| 187 |
|
| 188 |
+
# If local path doesn't exist, try loading from Hub
|
| 189 |
try:
|
| 190 |
CURRENT_TOKENIZER = AutoTokenizer.from_pretrained(model_path)
|
| 191 |
CURRENT_MODEL = BertForSequenceClassification.from_pretrained(
|
|
|
|
| 194 |
)
|
| 195 |
return f"โ
Model loaded from Hugging Face Hub: {model_path}"
|
| 196 |
except Exception as hub_error:
|
| 197 |
+
# If both local and hub loading fail, fall back to base model
|
| 198 |
CURRENT_TOKENIZER = AutoTokenizer.from_pretrained("bert-base-uncased")
|
| 199 |
CURRENT_MODEL = BertForSequenceClassification.from_pretrained(
|
| 200 |
"bert-base-uncased",
|
|
|
|
| 230 |
'recall_macro': report['macro avg']['recall']
|
| 231 |
}
|
| 232 |
|
| 233 |
+
def train_model_inline(uploaded_file, text_column, label_column, num_epochs, batch_size,
|
| 234 |
learning_rate, hf_token, push_to_hub, username, model_name):
|
| 235 |
"""Train the model using inline training (no subprocess)"""
|
| 236 |
global TRAINING_LOGS, MODEL_PATH, CURRENT_MODEL, CURRENT_TOKENIZER
|
|
|
|
| 242 |
TRAINING_LOGS.append(login_result)
|
| 243 |
yield "\n".join(TRAINING_LOGS)
|
| 244 |
|
| 245 |
+
# Validate hub model ID if pushing to hub
|
| 246 |
if push_to_hub:
|
| 247 |
hub_model_id, error = validate_hub_model_id(username, model_name)
|
| 248 |
if error:
|
|
|
|
| 252 |
else:
|
| 253 |
hub_model_id = None
|
| 254 |
|
| 255 |
+
# Validate uploaded file
|
| 256 |
if uploaded_file is None:
|
| 257 |
TRAINING_LOGS.append("โ Please upload a dataset file")
|
| 258 |
yield "\n".join(TRAINING_LOGS)
|
|
|
|
| 261 |
dataset_file = uploaded_file.name if hasattr(uploaded_file, 'name') else uploaded_file
|
| 262 |
|
| 263 |
try:
|
| 264 |
+
# Load and prepare dataset
|
| 265 |
TRAINING_LOGS.append(f"๐ Loading dataset from uploaded file...")
|
| 266 |
yield "\n".join(TRAINING_LOGS)
|
| 267 |
|
|
|
|
| 274 |
TRAINING_LOGS.append(f"- Validation samples: {len(dataset_dict['validation'])}")
|
| 275 |
yield "\n".join(TRAINING_LOGS)
|
| 276 |
|
| 277 |
+
# Load model and tokenizer
|
| 278 |
TRAINING_LOGS.append("๐ค Loading BERT model and tokenizer...")
|
| 279 |
yield "\n".join(TRAINING_LOGS)
|
| 280 |
|
|
|
|
| 287 |
TRAINING_LOGS.append("โ
Model and tokenizer loaded")
|
| 288 |
yield "\n".join(TRAINING_LOGS)
|
| 289 |
|
| 290 |
+
# Tokenize datasets
|
| 291 |
TRAINING_LOGS.append("๐ค Tokenizing datasets...")
|
| 292 |
yield "\n".join(TRAINING_LOGS)
|
| 293 |
|
| 294 |
def tokenize_batch(examples):
|
| 295 |
return tokenize_function(examples, tokenizer, final_text_col, 512)
|
| 296 |
|
| 297 |
+
# Get columns to remove (keep only label column and tokenized features)
|
| 298 |
columns_to_remove = [col for col in dataset_dict['train'].column_names if col != final_label_col]
|
| 299 |
|
| 300 |
tokenized_datasets = dataset_dict.map(
|
|
|
|
| 303 |
remove_columns=columns_to_remove
|
| 304 |
)
|
| 305 |
|
| 306 |
+
# Rename label column to 'labels' (required by Trainer)
|
| 307 |
tokenized_datasets = tokenized_datasets.rename_column(final_label_col, 'labels')
|
| 308 |
|
| 309 |
TRAINING_LOGS.append("โ
Tokenization completed")
|
| 310 |
yield "\n".join(TRAINING_LOGS)
|
| 311 |
|
| 312 |
+
# Set up training
|
| 313 |
output_dir = Path(MODEL_PATH)
|
| 314 |
output_dir.mkdir(parents=True, exist_ok=True)
|
| 315 |
|
| 316 |
+
# Calculate steps
|
| 317 |
total_steps = len(tokenized_datasets['train']) // batch_size * num_epochs
|
| 318 |
eval_steps = max(10, min(100, total_steps // 4))
|
| 319 |
save_steps = max(20, min(500, total_steps // 2))
|
|
|
|
| 326 |
TRAINING_LOGS.append(f"- Warmup steps: {warmup_steps}")
|
| 327 |
yield "\n".join(TRAINING_LOGS)
|
| 328 |
|
| 329 |
+
# Training arguments
|
| 330 |
training_args = TrainingArguments(
|
| 331 |
output_dir=str(output_dir),
|
| 332 |
num_train_epochs=num_epochs,
|
|
|
|
| 353 |
remove_unused_columns=False,
|
| 354 |
)
|
| 355 |
|
| 356 |
+
# Data collator
|
| 357 |
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
|
| 358 |
|
| 359 |
+
# Create trainer
|
| 360 |
trainer = Trainer(
|
| 361 |
model=model,
|
| 362 |
args=training_args,
|
|
|
|
| 371 |
TRAINING_LOGS.append("๐ Starting training...")
|
| 372 |
yield "\n".join(TRAINING_LOGS)
|
| 373 |
|
| 374 |
+
# Custom training loop with progress updates
|
| 375 |
class ProgressCallback:
|
| 376 |
def __init__(self, logs_list):
|
| 377 |
self.logs = logs_list
|
|
|
|
| 395 |
progress_callback = ProgressCallback(TRAINING_LOGS)
|
| 396 |
trainer.add_callback(progress_callback)
|
| 397 |
|
| 398 |
+
# Train the model
|
| 399 |
try:
|
| 400 |
trainer.train()
|
| 401 |
TRAINING_LOGS.append("โ
Training completed successfully!")
|
|
|
|
| 405 |
yield "\n".join(TRAINING_LOGS)
|
| 406 |
return
|
| 407 |
|
| 408 |
+
# Save the model
|
| 409 |
TRAINING_LOGS.append("๐พ Saving model...")
|
| 410 |
yield "\n".join(TRAINING_LOGS)
|
| 411 |
|
| 412 |
trainer.save_model()
|
| 413 |
tokenizer.save_pretrained(output_dir)
|
| 414 |
|
| 415 |
+
# Update global model and tokenizer
|
| 416 |
CURRENT_MODEL = model
|
| 417 |
CURRENT_TOKENIZER = tokenizer
|
| 418 |
|
| 419 |
TRAINING_LOGS.append("โ
Model saved successfully!")
|
| 420 |
yield "\n".join(TRAINING_LOGS)
|
| 421 |
|
| 422 |
+
# Final evaluation
|
| 423 |
TRAINING_LOGS.append("๐ Running final evaluation...")
|
| 424 |
yield "\n".join(TRAINING_LOGS)
|
| 425 |
|
|
|
|
| 432 |
else:
|
| 433 |
TRAINING_LOGS.append(f" {key}: {value}")
|
| 434 |
|
| 435 |
+
# Save results
|
| 436 |
with open(output_dir / "eval_results.json", "w") as f:
|
| 437 |
json.dump(eval_results, f, indent=2)
|
| 438 |
|
|
|
|
| 441 |
|
| 442 |
yield "\n".join(TRAINING_LOGS)
|
| 443 |
|
| 444 |
+
# Push to hub if requested
|
| 445 |
if push_to_hub and hub_model_id:
|
| 446 |
TRAINING_LOGS.append(f"๐ค Pushing to Hugging Face Hub: {hub_model_id}")
|
| 447 |
yield "\n".join(TRAINING_LOGS)
|
|
|
|
| 465 |
"""Make a prediction on a single text input"""
|
| 466 |
global CURRENT_MODEL, CURRENT_TOKENIZER
|
| 467 |
|
| 468 |
+
# Load the model if it's not loaded or a different one is requested
|
| 469 |
if CURRENT_MODEL is None or model_path != MODEL_PATH:
|
| 470 |
load_result = load_model(model_path)
|
| 471 |
if load_result.startswith("โ"):
|
|
|
|
| 475 |
if not text.strip():
|
| 476 |
return "Please enter some text to classify."
|
| 477 |
|
| 478 |
+
# Check if text was truncated
|
| 479 |
original_tokens = CURRENT_TOKENIZER(text, truncation=False)
|
| 480 |
was_truncated = len(original_tokens['input_ids']) > 512
|
| 481 |
|
| 482 |
+
# Tokenize input
|
| 483 |
inputs = CURRENT_TOKENIZER(text, return_tensors="pt", truncation=True, max_length=512)
|
| 484 |
|
| 485 |
+
# Make prediction
|
| 486 |
with torch.no_grad():
|
| 487 |
outputs = CURRENT_MODEL(**inputs)
|
| 488 |
predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
|
| 489 |
predicted_class_id = predictions.argmax().item()
|
| 490 |
confidence = predictions.max().item()
|
| 491 |
|
| 492 |
+
# Get predicted category
|
| 493 |
predicted_category = idx_to_category[predicted_class_id]
|
| 494 |
|
| 495 |
+
# Format result
|
| 496 |
truncation_warning = "\n\nโ ๏ธ Note: This complaint was truncated to fit BERT's 512 token limit." if was_truncated else ""
|
| 497 |
|
| 498 |
result = []
|
|
|
|
| 516 |
"""Make predictions on a CSV file with complaints"""
|
| 517 |
global CURRENT_MODEL, CURRENT_TOKENIZER
|
| 518 |
|
| 519 |
+
# Load the model if needed
|
| 520 |
if CURRENT_MODEL is None or model_path != MODEL_PATH:
|
| 521 |
load_result = load_model(model_path)
|
| 522 |
if load_result.startswith("โ"):
|
| 523 |
return load_result, None
|
| 524 |
|
| 525 |
try:
|
| 526 |
+
# Read the CSV file
|
| 527 |
if hasattr(csv_file, 'name'):
|
| 528 |
df = pd.read_csv(csv_file.name)
|
| 529 |
else:
|
|
|
|
| 539 |
for i, row in enumerate(df.iterrows()):
|
| 540 |
complaint = str(row[1]['complaint'])
|
| 541 |
|
| 542 |
+
# Check for truncation
|
| 543 |
original_tokens = CURRENT_TOKENIZER(complaint, truncation=False)
|
| 544 |
was_truncated = len(original_tokens['input_ids']) > 512
|
| 545 |
if was_truncated:
|
| 546 |
truncated_count += 1
|
| 547 |
|
| 548 |
+
# Predict
|
| 549 |
inputs = CURRENT_TOKENIZER(complaint, return_tensors="pt", truncation=True, max_length=512)
|
| 550 |
with torch.no_grad():
|
| 551 |
outputs = CURRENT_MODEL(**inputs)
|
|
|
|
| 573 |
if truncated_count > 0:
|
| 574 |
results.append(f"\nโ ๏ธ {truncated_count} complaints were truncated to fit BERT's 512 token limit.")
|
| 575 |
|
| 576 |
+
# Save full results to a CSV file
|
| 577 |
results_df = pd.DataFrame(predictions_list)
|
| 578 |
results_file = "prediction_results.csv"
|
| 579 |
results_df.to_csv(results_file, index=False)
|
|
|
|
| 594 |
if error:
|
| 595 |
return f"โ {error}"
|
| 596 |
|
| 597 |
+
# Login and load model
|
| 598 |
login(token)
|
| 599 |
if not os.path.exists(model_path):
|
| 600 |
return "โ No trained model found. Please train a model first."
|
|
|
|
| 605 |
except Exception as e:
|
| 606 |
return f"โ Failed to load model: {str(e)}"
|
| 607 |
|
| 608 |
+
# Push to Hub
|
| 609 |
try:
|
| 610 |
model.push_to_hub(hub_model_id)
|
| 611 |
tokenizer.push_to_hub(hub_model_id)
|
|
|
|
| 650 |
else:
|
| 651 |
return "No CSV files found in the current directory."
|
| 652 |
|
|
|
|
|
|
|
| 653 |
# Initialize tokenizer on startup
|
| 654 |
if CURRENT_TOKENIZER is None:
|
| 655 |
try:
|
|
|
|
| 661 |
print("๐ Launching BERT Complaint Classifier...")
|
| 662 |
print("๐ Available at: http://localhost:7860")
|
| 663 |
|
| 664 |
+
# The entire Gradio UI definition must be within a single block
|
| 665 |
with gr.Blocks(title="BERT Complaint Classifier", theme=gr.themes.Soft()) as app:
|
| 666 |
gr.Markdown("# BERT Complaint Classifier ๐ฃ๏ธ๐ค")
|
| 667 |
gr.Markdown("Fine-tune a BERT model or use an existing one to classify customer complaints.")
|
|
|
|
| 711 |
with gr.Column(variant="panel"):
|
| 712 |
gr.Markdown("### Classify a Single Complaint")
|
| 713 |
|
| 714 |
+
model_path_input = gr.Textbox(label="Model Path", value=MODEL_PATH, placeholder="Enter model path or HuggingFace model ID")
|
| 715 |
+
load_model_btn = gr.Button("Load Model")
|
| 716 |
+
model_status = gr.Textbox(label="Model Status", interactive=False)
|
|
|
|
|
|
|
| 717 |
|
| 718 |
+
gr.Markdown("---")
|
|
|
|
|
|
|
| 719 |
|
| 720 |
+
text_input = gr.Textbox(
|
| 721 |
+
label="Enter complaint text",
|
| 722 |
+
lines=3,
|
| 723 |
+
placeholder="Type your complaint here..."
|
| 724 |
+
)
|
| 725 |
+
token_counter = gr.Textbox(label="Token Count", interactive=False, value="Enter text to see token count")
|
| 726 |
|
| 727 |
+
predict_btn = gr.Button("๐ฎ Predict Category", variant="primary")
|
| 728 |
+
|
| 729 |
+
prediction_output = gr.Markdown("Prediction results will appear here")
|
| 730 |
+
|
| 731 |
+
with gr.Tab("Predict CSV File"):
|
| 732 |
with gr.Column(variant="panel"):
|
| 733 |
+
gr.Markdown("### Classify Multiple Complaints from CSV")
|
| 734 |
+
gr.Markdown("Upload a CSV file with a 'complaint' column to classify multiple complaints at once.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 735 |
|
| 736 |
+
csv_model_path = gr.Textbox(label="Model Path", value=MODEL_PATH, placeholder="Enter model path or HuggingFace model ID")
|
| 737 |
+
csv_load_btn = gr.Button("Load Model")
|
| 738 |
+
csv_model_status = gr.Textbox(label="Model Status", interactive=False)
|
| 739 |
+
|
| 740 |
+
gr.Markdown("---")
|
| 741 |
+
|
| 742 |
+
csv_file_input = gr.File(label="Upload CSV File", type="filepath", file_types=["csv"])
|
| 743 |
+
csv_predict_btn = gr.Button("๐ฎ Predict All", variant="primary")
|
| 744 |
+
|
| 745 |
+
csv_prediction_output = gr.Markdown("CSV prediction results will appear here")
|
| 746 |
+
csv_download = gr.File(label="Download Results", interactive=False)
|
| 747 |
+
|
| 748 |
+
with gr.Tab("Push to Hub"):
|
| 749 |
+
gr.Markdown("## ๐ค Push Trained Model to Hugging Face Hub")
|
| 750 |
+
gr.Markdown("Upload your locally trained model to the Hugging Face Hub for sharing.")
|
| 751 |
|
| 752 |
+
with gr.Column(variant="panel"):
|
| 753 |
+
hub_model_path = gr.Textbox(label="Local Model Path", value=MODEL_PATH)
|
| 754 |
+
hub_username = gr.Textbox(label="Hugging Face Username")
|
| 755 |
+
hub_model_name = gr.Textbox(label="Model Name", value="bert-complaint-classifier")
|
| 756 |
+
hub_token = gr.Textbox(label="Hugging Face Token", type="password")
|
| 757 |
+
|
| 758 |
+
push_hub_btn = gr.Button("๐ Push to Hub", variant="primary")
|
| 759 |
+
|
| 760 |
+
push_hub_output = gr.Markdown("Push results will appear here")
|
| 761 |
|
| 762 |
+
with gr.Tab("Dataset Info"):
|
| 763 |
+
gr.Markdown("## ๐ Dataset Information")
|
| 764 |
+
gr.Markdown("View information about available datasets and model categories.")
|
| 765 |
+
|
| 766 |
+
with gr.Column(variant="panel"):
|
| 767 |
+
gr.Markdown("### ๐ฏ Model Categories")
|
| 768 |
+
categories_info = gr.Markdown(f"**Available Categories:**\n\n" + "\n".join([f"- **{cat}** (index: {idx})" for idx, cat in idx_to_category.items()]))
|
| 769 |
+
|
| 770 |
+
gr.Markdown("---")
|
| 771 |
|
| 772 |
+
gr.Markdown("### ๐ Available Datasets")
|
| 773 |
+
datasets_btn = gr.Button("๐ Scan for CSV Files")
|
| 774 |
+
datasets_info = gr.Markdown("Click 'Scan for CSV Files' to see available datasets")
|
| 775 |
|
| 776 |
+
gr.Markdown("---")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 777 |
|
| 778 |
+
gr.Markdown("### ๐ก Tips")
|
| 779 |
+
gr.Markdown("""
|
| 780 |
+
**Dataset Format:**
|
| 781 |
+
- CSV file with at least two columns
|
| 782 |
+
- One column for text (complaints)
|
| 783 |
+
- One column for labels/categories
|
| 784 |
+
- Labels can be text (will be auto-mapped) or numeric indices (0, 1, 2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 785 |
|
| 786 |
+
**Training Tips:**
|
| 787 |
+
- Start with 3 epochs and adjust based on results
|
| 788 |
+
- Use batch size 8-16 for most datasets
|
| 789 |
+
- Learning rate 2e-5 works well for BERT fine-tuning
|
| 790 |
+
- Enable early stopping to prevent overfitting
|
| 791 |
+
""")
|
| 792 |
|
| 793 |
+
# Connect functions to UI components
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 794 |
preview_btn.click(
|
| 795 |
preview_dataset,
|
| 796 |
inputs=[uploaded_file, text_column_input, label_column_input],
|
|
|
|
| 814 |
outputs=training_log_output,
|
| 815 |
)
|
| 816 |
|
| 817 |
+
load_model_btn.click(
|
| 818 |
+
load_model,
|
| 819 |
+
inputs=model_path_input,
|
| 820 |
+
outputs=model_status
|
| 821 |
+
)
|
| 822 |
+
|
| 823 |
+
predict_btn.click(
|
| 824 |
+
predict_text,
|
| 825 |
+
inputs=[text_input, model_path_input],
|
| 826 |
+
outputs=prediction_output
|
| 827 |
+
)
|
| 828 |
+
|
| 829 |
+
text_input.change(
|
| 830 |
+
count_tokens,
|
| 831 |
+
inputs=text_input,
|
| 832 |
+
outputs=token_counter
|
| 833 |
+
)
|
| 834 |
+
|
| 835 |
+
csv_load_btn.click(
|
| 836 |
+
load_model,
|
| 837 |
+
inputs=csv_model_path,
|
| 838 |
+
outputs=csv_model_status
|
| 839 |
+
)
|
| 840 |
+
|
| 841 |
+
csv_predict_btn.click(
|
| 842 |
+
predict_csv,
|
| 843 |
+
inputs=[csv_file_input, csv_model_path],
|
| 844 |
+
outputs=[csv_prediction_output, csv_download]
|
| 845 |
+
)
|
| 846 |
+
|
| 847 |
+
push_hub_btn.click(
|
| 848 |
+
push_to_hub_after_training,
|
| 849 |
+
inputs=[hub_model_path, hub_username, hub_model_name, hub_token],
|
| 850 |
+
outputs=push_hub_output
|
| 851 |
+
)
|
| 852 |
+
|
| 853 |
+
datasets_btn.click(
|
| 854 |
display_available_datasets,
|
| 855 |
+
outputs=datasets_info
|
| 856 |
)
|
| 857 |
|
| 858 |
+
# Run a check for available datasets on app load
|
| 859 |
+
app.load(display_available_datasets, outputs=datasets_info)
|
| 860 |
|
| 861 |
+
# Launch the Gradio app
|
| 862 |
if __name__ == "__main__":
|
| 863 |
app.launch(
|
| 864 |
server_name="0.0.0.0",
|