b2u committed on
Commit
b262013
·
1 Parent(s): 4a96163

making it start training

Browse files
Files changed (1) hide show
  1. model.py +65 -31
model.py CHANGED
@@ -142,54 +142,88 @@ class BertClassifier(LabelStudioMLBase):
142
  return predictions
143
 
144
  def fit(self, event_data, data=None, **kwargs):
145
- """Train the model on the labeled data."""
146
- logger.info("=== WEBHOOK DEBUG INFO ===")
147
- logger.info(f"event_data type: {type(event_data)}")
148
- logger.info(f"event_data content: {event_data}")
149
- logger.info(f"data type: {type(data)}")
150
- logger.info(f"data content: {data}")
151
- logger.info(f"kwargs: {kwargs}")
152
- logger.info("=== END WEBHOOK DEBUG INFO ===")
153
-
154
- logger.info(f"Received event: {event_data}")
155
 
156
  try:
157
  if event_data == 'ANNOTATION_CREATED':
158
- # Extract text and label directly from the data
159
  annotation = data.get('annotation', {})
160
  task = data.get('task', {})
161
 
 
 
 
162
  if not task or not annotation:
163
  logger.error("Missing task or annotation data")
164
  return {'status': 'error', 'message': 'Missing task or annotation data'}
165
 
166
- # Get the text from task data
167
  text = task.get('data', {}).get('text', '')
168
-
169
- # Get the label from annotation results
170
  results = annotation.get('result', [])
 
171
  for result in results:
172
  if result.get('type') == 'choices':
173
  label = result.get('value', {}).get('choices', [])[0]
174
- logger.info(f"Processing annotation - Text: {text[:50]}... Label: {label}")
175
 
176
- # Here you would add your training logic
177
- # For now, let's just log it
178
- logger.info(f"Would train model on text: '{text}' with label: '{label}'")
179
-
180
- return {
181
- 'status': 'ok',
182
- 'message': f'Added training data: {text[:50]}... -> {label}'
183
- }
184
-
185
- elif event_data == 'START_TRAINING':
186
- # This event indicates we should start a training cycle
187
- logger.info("Received START_TRAINING event")
188
- # Here you would implement the actual training logic
189
- return {'status': 'ok', 'message': 'Training cycle started'}
190
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
  except Exception as e:
192
- logger.error(f"Error during training: {str(e)}")
193
  logger.error("Full error details:", exc_info=True)
194
  return {'status': 'error', 'message': str(e)}
195
 
 
142
  return predictions
143
 
144
  def fit(self, event_data, data=None, **kwargs):
145
+ """Train the model on a single annotation."""
146
+ start_time = datetime.now()
147
+ logger.info(f"=== FIT METHOD CALLED ===")
148
+ logger.info(f"Event data: {event_data}")
149
+ logger.info(f"Data received: {json.dumps(data, indent=2)}")
 
 
 
 
 
150
 
151
  try:
152
  if event_data == 'ANNOTATION_CREATED':
153
+ logger.info("Processing ANNOTATION_CREATED event")
154
  annotation = data.get('annotation', {})
155
  task = data.get('task', {})
156
 
157
+ logger.info(f"Annotation data: {json.dumps(annotation, indent=2)}")
158
+ logger.info(f"Task data: {json.dumps(task, indent=2)}")
159
+
160
  if not task or not annotation:
161
  logger.error("Missing task or annotation data")
162
  return {'status': 'error', 'message': 'Missing task or annotation data'}
163
 
164
+ # Extract text and label
165
  text = task.get('data', {}).get('text', '')
 
 
166
  results = annotation.get('result', [])
167
+
168
  for result in results:
169
  if result.get('type') == 'choices':
170
  label = result.get('value', {}).get('choices', [])[0]
171
+ logger.info(f"Training on - Text: {text[:50]}... Label: {label}")
172
 
173
+ try:
174
+ # Create dataset for single example
175
+ dataset = TextDataset(
176
+ texts=[text],
177
+ labels=[self.categories.index(label)],
178
+ tokenizer=self.tokenizer
179
+ )
180
+ train_loader = DataLoader(dataset, batch_size=1)
181
+
182
+ # Setup training
183
+ optimizer = AdamW(self._model.parameters(), lr=2e-5)
184
+ self._model.train()
185
+
186
+ # Single example training
187
+ for batch in train_loader:
188
+ optimizer.zero_grad()
189
+
190
+ # Move batch to device
191
+ input_ids = batch['input_ids'].to(self.device)
192
+ attention_mask = batch['attention_mask'].to(self.device)
193
+ labels = batch['labels'].to(self.device)
194
+
195
+ # Forward pass
196
+ outputs = self._model(
197
+ input_ids=input_ids,
198
+ attention_mask=attention_mask,
199
+ labels=labels
200
+ )
201
+
202
+ loss = outputs.loss
203
+ logger.info(f"Training loss: {loss.item()}")
204
+
205
+ # Backward pass
206
+ loss.backward()
207
+ optimizer.step()
208
+
209
+ # Save the model after training
210
+ model_path = os.path.join(self.model_dir, 'model_state.pt')
211
+ torch.save(self._model.state_dict(), model_path)
212
+ logger.info(f"Model saved to {model_path}")
213
+
214
+ return {
215
+ 'status': 'ok',
216
+ 'message': f'Successfully trained on: {text[:50]}... -> {label}',
217
+ 'time_taken': str(datetime.now() - start_time)
218
+ }
219
+
220
+ except Exception as e:
221
+ logger.error(f"Training error: {str(e)}")
222
+ logger.error("Full error details:", exc_info=True)
223
+ return {'status': 'error', 'message': f'Training failed: {str(e)}'}
224
+
225
  except Exception as e:
226
+ logger.error(f"Error in fit method: {str(e)}")
227
  logger.error("Full error details:", exc_info=True)
228
  return {'status': 'error', 'message': str(e)}
229