b2u commited on
Commit
1589415
Β·
1 Parent(s): 1331b4b

adding proper training cycle handling

Browse files
Files changed (1) hide show
  1. model.py +42 -54
model.py CHANGED
@@ -17,13 +17,6 @@ current_dir = Path(__file__).parent
17
 
18
  logger = logging.getLogger(__name__)
19
 
20
- # Add these debug lines
21
- logger.info("=== DEBUG INFO ===")
22
- logger.info(f"Python path: {sys.path}")
23
- logger.info(f"Current directory: {os.getcwd()}")
24
- logger.info(f"Directory contents: {os.listdir('.')}")
25
- logger.info("=== END DEBUG INFO ===")
26
-
27
  # Move TextDataset class here
28
  class TextDataset(Dataset):
29
  def __init__(self, texts, labels, tokenizer, max_length=128):
@@ -43,10 +36,10 @@ class BertClassifier(LabelStudioMLBase):
43
  super(BertClassifier, self).__init__(project_id=project_id, label_config=label_config)
44
 
45
  # Load training configuration from environment variables
46
- self.learning_rate = float(os.getenv('LEARNING_RATE', 2e-5))
47
- self.num_train_epochs = int(os.getenv('NUM_TRAIN_EPOCHS', 2))
48
- self.weight_decay = float(os.getenv('WEIGHT_DECAY', 0.01))
49
- self.start_training_threshold = int(os.getenv('START_TRAINING_EACH_N_UPDATES', 1))
50
 
51
  logger.info("=== Training Configuration ===")
52
  logger.info(f"βœ“ Learning rate: {self.learning_rate}")
@@ -55,38 +48,28 @@ class BertClassifier(LabelStudioMLBase):
55
  logger.info(f"βœ“ Training threshold: {self.start_training_threshold}")
56
  logger.info("============================")
57
 
58
- logger.info(f"Initializing BertClassifier with project_id: {project_id}")
59
- logger.info(f"Label config length: {len(label_config) if label_config else 0}")
60
-
61
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
62
  logger.info(f"Using device: {self.device}")
63
 
64
- # Define categories
65
- self.categories = [
66
- 'affiliate_classification', 'brand', 'business_and_career',
67
- 'content_quality', 'date', 'demographic', 'event',
68
- 'faith_and_religion', 'gaming', 'health',
69
- 'internal_categorization', 'location', 'number',
70
- 'performance', 'post_type', 'pricing_tier',
71
- 'product', 'profession', 'pii', 'social_network',
72
- 'style_and_fashion', 'no_category'
73
- ]
74
 
75
- self.model_dir = os.path.join(os.path.dirname(__file__), 'model')
76
- os.makedirs(self.model_dir, exist_ok=True)
 
 
 
 
 
 
77
 
78
- # Initialize model and tokenizer
79
- try:
80
- self._model = AutoModelForSequenceClassification.from_pretrained(
81
- 'bert-base-uncased',
82
- num_labels=len(self.categories)
83
- )
84
- self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
85
- self._model.to(self.device)
86
- logger.info("Successfully loaded BERT model and tokenizer")
87
- except Exception as e:
88
- logger.error(f"Error loading model: {str(e)}")
89
- logger.error("Full error details:", exc_info=True)
90
 
91
  def predict(self, tasks, **kwargs):
92
  """Generate predictions for a list of tasks."""
@@ -150,7 +133,6 @@ class BertClassifier(LabelStudioMLBase):
150
  return predictions
151
 
152
  def fit(self, event_data, data=None, **kwargs):
153
- """Train the model on a single annotation."""
154
  start_time = datetime.now()
155
  logger.info("=== FIT METHOD CALLED ===")
156
 
@@ -196,11 +178,12 @@ class BertClassifier(LabelStudioMLBase):
196
  self._model.train()
197
  logger.info("Starting training...")
198
 
199
- # Multi-epoch training
 
200
  for epoch in range(self.num_train_epochs):
201
  logger.info(f"Starting epoch {epoch + 1}/{self.num_train_epochs}")
 
202
 
203
- # Single example training
204
  for batch in train_loader:
205
  optimizer.zero_grad()
206
 
@@ -217,27 +200,32 @@ class BertClassifier(LabelStudioMLBase):
217
  )
218
 
219
  loss = outputs.loss
220
- logger.info(f"Training loss: {loss.item():.4f}")
221
 
222
  # Backward pass
223
  loss.backward()
224
  optimizer.step()
225
 
226
- # Save the model
227
- model_path = os.path.join(self.model_dir, 'model_state.pt')
228
- torch.save(self._model.state_dict(), model_path)
229
- logger.info(f"βœ“ Model saved to {model_path}")
230
-
231
- return {
232
- 'status': 'ok',
233
- 'message': f'Successfully trained on: {text[:50]}... -> {label}',
234
- 'time_taken': str(datetime.now() - start_time)
235
- }
 
 
 
 
 
 
236
 
237
  except Exception as e:
238
  logger.error(f"Training error: {str(e)}")
239
- logger.error("Full error details:", exc_info=True)
240
- return {'status': 'error', 'message': f'Training failed: {str(e)}'}
241
 
242
  except Exception as e:
243
  logger.error(f"Error in fit method: {str(e)}")
 
17
 
18
  logger = logging.getLogger(__name__)
19
 
 
 
 
 
 
 
 
20
  # Move TextDataset class here
21
  class TextDataset(Dataset):
22
  def __init__(self, texts, labels, tokenizer, max_length=128):
 
36
  super(BertClassifier, self).__init__(project_id=project_id, label_config=label_config)
37
 
38
  # Load training configuration from environment variables
39
+ self.learning_rate = float(os.getenv('LEARNING_RATE'))
40
+ self.num_train_epochs = int(os.getenv('NUM_TRAIN_EPOCHS'))
41
+ self.weight_decay = float(os.getenv('WEIGHT_DECAY'))
42
+ self.start_training_threshold = int(os.getenv('START_TRAINING_EACH_N_UPDATES'))
43
 
44
  logger.info("=== Training Configuration ===")
45
  logger.info(f"βœ“ Learning rate: {self.learning_rate}")
 
48
  logger.info(f"βœ“ Training threshold: {self.start_training_threshold}")
49
  logger.info("============================")
50
 
51
+ # Initialize model and move to device
 
 
52
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
53
  logger.info(f"Using device: {self.device}")
54
 
55
+ # Initialize model
56
+ self._model = AutoModelForSequenceClassification.from_pretrained(
57
+ 'bert-base-uncased',
58
+ num_labels=len(self.categories)
59
+ )
60
+ self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
 
 
 
 
61
 
62
+ # Load saved model if exists
63
+ model_path = os.path.join(self.model_dir, 'model_state.pt')
64
+ if os.path.exists(model_path):
65
+ try:
66
+ self._model.load_state_dict(torch.load(model_path))
67
+ logger.info(f"βœ“ Loaded saved model from {model_path}")
68
+ except Exception as e:
69
+ logger.error(f"Failed to load model: {str(e)}")
70
 
71
+ self._model.to(self.device)
72
+ logger.info("βœ“ Model ready")
 
 
 
 
 
 
 
 
 
 
73
 
74
  def predict(self, tasks, **kwargs):
75
  """Generate predictions for a list of tasks."""
 
133
  return predictions
134
 
135
  def fit(self, event_data, data=None, **kwargs):
 
136
  start_time = datetime.now()
137
  logger.info("=== FIT METHOD CALLED ===")
138
 
 
178
  self._model.train()
179
  logger.info("Starting training...")
180
 
181
+ # Training loop
182
+ total_loss = 0
183
  for epoch in range(self.num_train_epochs):
184
  logger.info(f"Starting epoch {epoch + 1}/{self.num_train_epochs}")
185
+ epoch_loss = 0
186
 
 
187
  for batch in train_loader:
188
  optimizer.zero_grad()
189
 
 
200
  )
201
 
202
  loss = outputs.loss
203
+ epoch_loss += loss.item()
204
 
205
  # Backward pass
206
  loss.backward()
207
  optimizer.step()
208
 
209
+ avg_epoch_loss = epoch_loss / len(train_loader)
210
+ total_loss += avg_epoch_loss
211
+ logger.info(f"Epoch {epoch + 1} loss: {avg_epoch_loss:.4f}")
212
+
213
+ avg_training_loss = total_loss / self.num_train_epochs
214
+ logger.info(f"Average training loss: {avg_training_loss:.4f}")
215
+
216
+ # Save model
217
+ model_path = os.path.join(self.model_dir, 'model_state.pt')
218
+ torch.save(self._model.state_dict(), model_path)
219
+ logger.info(f"βœ“ Model saved to {model_path}")
220
+
221
+ return {
222
+ 'status': 'ok',
223
+ 'message': f'Training completed with avg loss: {avg_training_loss:.4f}'
224
+ }
225
 
226
  except Exception as e:
227
  logger.error(f"Training error: {str(e)}")
228
+ return {'status': 'error', 'message': str(e)}
 
229
 
230
  except Exception as e:
231
  logger.error(f"Error in fit method: {str(e)}")