msmaje committed on
Commit
1a06556
·
verified ·
1 Parent(s): bdc1139

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -30
app.py CHANGED
@@ -10,14 +10,15 @@ from pathlib import Path
10
  from sklearn.metrics import accuracy_score, classification_report
11
  from sklearn.model_selection import train_test_split
12
 
13
- from huggingface_hub import login, HfApi
14
  from transformers import (
15
  AutoTokenizer,
16
  BertForSequenceClassification,
17
  TrainingArguments,
18
  Trainer,
19
  DataCollatorWithPadding,
20
- EarlyStoppingCallback
 
21
  )
22
  from datasets import Dataset, DatasetDict
23
 
@@ -361,6 +362,42 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
361
 
362
  # Data collator
363
  data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
364
 
365
  # Create trainer
366
  trainer = Trainer(
@@ -371,36 +408,9 @@ def train_model_inline(uploaded_file, text_column, label_column, num_epochs, bat
371
  tokenizer=tokenizer,
372
  data_collator=data_collator,
373
  compute_metrics=compute_metrics,
374
- callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
375
  )
376
 
377
- TRAINING_LOGS.append("πŸš€ Starting training...")
378
- yield "\n".join(TRAINING_LOGS)
379
-
380
- # Custom training loop with progress updates
381
- class ProgressCallback:
382
- def __init__(self, logs_list):
383
- self.logs = logs_list
384
- self.step_count = 0
385
-
386
- def on_step_end(self, args, state, control, model=None, **kwargs):
387
- self.step_count += 1
388
- if self.step_count % logging_steps == 0:
389
- self.logs.append(f"Step {self.step_count}/{total_steps}")
390
-
391
- def on_epoch_end(self, args, state, control, model=None, **kwargs):
392
- epoch = int(state.epoch)
393
- self.logs.append(f"βœ… Epoch {epoch} completed")
394
-
395
- def on_evaluate(self, args, state, control, model=None, logs=None, **kwargs):
396
- if logs:
397
- acc = logs.get('eval_accuracy', 0)
398
- loss = logs.get('eval_loss', 0)
399
- self.logs.append(f"πŸ“Š Eval - Accuracy: {acc:.4f}, Loss: {loss:.4f}")
400
-
401
- progress_callback = ProgressCallback(TRAINING_LOGS)
402
- trainer.add_callback(progress_callback)
403
-
404
  # Train the model
405
  try:
406
  trainer.train()
 
10
  from sklearn.metrics import accuracy_score, classification_report
11
  from sklearn.model_selection import train_test_split
12
 
13
+ from huggingface_hub import login
14
  from transformers import (
15
  AutoTokenizer,
16
  BertForSequenceClassification,
17
  TrainingArguments,
18
  Trainer,
19
  DataCollatorWithPadding,
20
+ EarlyStoppingCallback,
21
+ TrainerCallback
22
  )
23
  from datasets import Dataset, DatasetDict
24
 
 
362
 
363
  # Data collator
364
  data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
365
+
366
+ # Corrected Callback Class
367
+ class ProgressCallback(TrainerCallback):
368
+ def __init__(self, logs_list, total_steps):
369
+ self.logs = logs_list
370
+ self.total_steps = total_steps
371
+
372
+ def on_train_begin(self, args, state, control, **kwargs):
373
+ self.logs.append("πŸš€ Starting training...")
374
+ self.log_update()
375
+
376
+ def on_step_end(self, args, state, control, **kwargs):
377
+ if state.global_step % args.logging_steps == 0:
378
+ self.logs.append(f"Step {state.global_step}/{self.total_steps}")
379
+ self.log_update()
380
+
381
+ def on_epoch_end(self, args, state, control, **kwargs):
382
+ epoch = int(state.epoch)
383
+ self.logs.append(f"βœ… Epoch {epoch} completed")
384
+ self.log_update()
385
+
386
+ def on_evaluate(self, args, state, control, logs=None, **kwargs):
387
+ if logs:
388
+ acc = logs.get('eval_accuracy', 0)
389
+ loss = logs.get('eval_loss', 0)
390
+ self.logs.append(f"πŸ“Š Eval - Accuracy: {acc:.4f}, Loss: {loss:.4f}")
391
+ self.log_update()
392
+
393
+ def log_update(self):
394
+ # This is a custom helper to yield updates to the Gradio UI
395
+ # The original code did this manually, but with TrainerCallback,
396
+ # we can't do that. So we log to the list and rely on the UI
397
+ # to refresh. For a real-time stream, this part would need to be
398
+ # handled by Gradio's streaming feature, but this approach
399
+ # is sufficient for the user's current setup.
400
+ pass
401
 
402
  # Create trainer
403
  trainer = Trainer(
 
408
  tokenizer=tokenizer,
409
  data_collator=data_collator,
410
  compute_metrics=compute_metrics,
411
+ callbacks=[EarlyStoppingCallback(early_stopping_patience=3), ProgressCallback(TRAINING_LOGS, total_steps)]
412
  )
413
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
414
  # Train the model
415
  try:
416
  trainer.train()