Update app.py

app.py CHANGED
@@ -10,14 +10,15 @@ from pathlib import Path
 from sklearn.metrics import accuracy_score, classification_report
 from sklearn.model_selection import train_test_split
 
-from huggingface_hub import login
+from huggingface_hub import login
 from transformers import (
     AutoTokenizer,
     BertForSequenceClassification,
     TrainingArguments,
     Trainer,
     DataCollatorWithPadding,
-    EarlyStoppingCallback
+    EarlyStoppingCallback,
+    TrainerCallback
 )
 from datasets import Dataset, DatasetDict
 
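The newly imported EarlyStoppingCallback only runs if TrainingArguments enables periodic evaluation and best-model tracking; transformers raises an assertion error at train start otherwise. A minimal sketch of the arguments it needs (values are hypothetical, not taken from this app):

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",                # hypothetical path
    num_train_epochs=3,                    # stand-in for the UI's num_epochs
    per_device_train_batch_size=16,        # stand-in for the UI's batch_size
    eval_strategy="epoch",                 # named evaluation_strategy in older releases
    save_strategy="epoch",                 # must align with eval for best-model loading
    load_best_model_at_end=True,           # required by EarlyStoppingCallback
    metric_for_best_model="eval_accuracy", # metric the patience counter watches
)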
@@ -361,6 +362,42 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
 
     # Data collator
     data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
+
+    # Corrected callback class, built on transformers' TrainerCallback
+    class ProgressCallback(TrainerCallback):
+        def __init__(self, logs_list, total_steps):
+            self.logs = logs_list
+            self.total_steps = total_steps
+
+        def on_train_begin(self, args, state, control, **kwargs):
+            self.logs.append("🚀 Starting training...")
+            self.log_update()
+
+        def on_step_end(self, args, state, control, **kwargs):
+            if state.global_step % args.logging_steps == 0:
+                self.logs.append(f"Step {state.global_step}/{self.total_steps}")
+                self.log_update()
+
+        def on_epoch_end(self, args, state, control, **kwargs):
+            epoch = int(state.epoch)
+            self.logs.append(f"✅ Epoch {epoch} completed")
+            self.log_update()
+
+        def on_evaluate(self, args, state, control, metrics=None, **kwargs):
+            if metrics:
+                acc = metrics.get('eval_accuracy', 0)
+                loss = metrics.get('eval_loss', 0)
+                self.logs.append(f"📊 Eval - Accuracy: {acc:.4f}, Loss: {loss:.4f}")
+                self.log_update()
+
+        def log_update(self):
+            # Helper hook for pushing updates to the Gradio UI. The original
+            # code yielded log lines manually, but a TrainerCallback runs
+            # inside trainer.train() and cannot yield to the UI directly,
+            # so entries are appended to the shared list and the interface
+            # refreshes from it. Real-time output would need Gradio's
+            # streaming support; this is sufficient for the current setup.
+            pass
 
     # Create trainer
     trainer = Trainer(
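The log_update comment above notes that a TrainerCallback cannot yield to the Gradio UI by itself. One way to get live updates anyway, sketched here under the assumption that the surrounding function is a Gradio generator (as the removed yield-based code below suggests); stream_training and its polling interval are illustrative, not part of the app:

import threading

def stream_training(trainer, logs):
    done = threading.Event()

    def _run():
        try:
            trainer.train()        # ProgressCallback appends to logs as it runs
        finally:
            done.set()             # unblock the polling loop even on failure

    threading.Thread(target=_run, daemon=True).start()
    while not done.is_set():
        yield "\n".join(logs)      # each yield re-renders the Gradio output
        done.wait(1.0)             # poll roughly once per second
    yield "\n".join(logs)          # emit the final log state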
@@ -371,36 +408,9 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
         tokenizer=tokenizer,
         data_collator=data_collator,
         compute_metrics=compute_metrics,
-        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
+        callbacks=[EarlyStoppingCallback(early_stopping_patience=3), ProgressCallback(TRAINING_LOGS, total_steps)]
     )
 
-    TRAINING_LOGS.append("🚀 Starting training...")
-    yield "\n".join(TRAINING_LOGS)
-
-    # Custom training loop with progress updates
-    class ProgressCallback:
-        def __init__(self, logs_list):
-            self.logs = logs_list
-            self.step_count = 0
-
-        def on_step_end(self, args, state, control, model=None, **kwargs):
-            self.step_count += 1
-            if self.step_count % logging_steps == 0:
-                self.logs.append(f"Step {self.step_count}/{total_steps}")
-
-        def on_epoch_end(self, args, state, control, model=None, **kwargs):
-            epoch = int(state.epoch)
-            self.logs.append(f"✅ Epoch {epoch} completed")
-
-        def on_evaluate(self, args, state, control, model=None, logs=None, **kwargs):
-            if logs:
-                acc = logs.get('eval_accuracy', 0)
-                loss = logs.get('eval_loss', 0)
-                self.logs.append(f"📊 Eval - Accuracy: {acc:.4f}, Loss: {loss:.4f}")
-
-    progress_callback = ProgressCallback(TRAINING_LOGS)
-    trainer.add_callback(progress_callback)
-
     # Train the model
     try:
         trainer.train()
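The new callbacks line passes a total_steps value into ProgressCallback, but its definition sits outside the hunks shown here. A plausible derivation, assuming a single device and no gradient accumulation (train_dataset, batch_size, and num_epochs are taken from the surrounding function):

import math

steps_per_epoch = math.ceil(len(train_dataset) / batch_size)  # optimizer steps per pass
total_steps = steps_per_epoch * num_epochs                    # what ProgressCallback reports against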