b2u committed on
Commit
d44186e
·
1 Parent(s): 743d8c8

Adding training initiation and logging

Browse files
Files changed (1) hide show
  1. model.py +53 -16
model.py CHANGED
@@ -7,6 +7,8 @@ import json
7
  import torch
8
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
9
  from label_studio_ml.model import LabelStudioMLBase, ModelResponse
 
 
10
 
11
  logger = logging.getLogger(__name__)
12
 
@@ -180,26 +182,61 @@ Category:"""
180
  return predictions
181
 
182
  def fit(self, event, data, **kwargs):
183
- """Handle annotation events from Label Studio
184
-
185
- Args:
186
- event (str): Event type ('ANNOTATION_CREATED', 'ANNOTATION_UPDATED', 'START_TRAINING')
187
- data (dict): Event payload with annotation details
188
- """
189
 
190
  valid_events = {'ANNOTATION_CREATED', 'ANNOTATION_UPDATED', 'START_TRAINING'}
191
  if event not in valid_events:
192
  logger.warning(f"Skip training: event {event} is not supported")
193
  return
194
 
195
- # Extract text and label
196
- text = data['task']['data']['text']
197
- label = data['annotation']['result'][0]['value']['choices'][0]
198
-
199
- training_data = {
200
- 'text': text,
201
- 'label': label
202
- }
203
-
204
- logger.info(f"Extracted training data: {json.dumps(training_data, indent=2)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
 
 
7
  import torch
8
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
9
  from label_studio_ml.model import LabelStudioMLBase, ModelResponse
10
+ from peft import get_peft_model, LoraConfig
11
+ import time
12
 
13
  logger = logging.getLogger(__name__)
14
 
 
182
  return predictions
183
 
184
  def fit(self, event, data, **kwargs):
185
+ """Handle annotation events from Label Studio"""
186
+ start_time = time.time()
187
+ logger.info("Starting training session...")
 
 
 
188
 
189
  valid_events = {'ANNOTATION_CREATED', 'ANNOTATION_UPDATED', 'START_TRAINING'}
190
  if event not in valid_events:
191
  logger.warning(f"Skip training: event {event} is not supported")
192
  return
193
 
194
+ try:
195
+ # Extract text and label
196
+ text = data['task']['data']['text']
197
+ label = data['annotation']['result'][0]['value']['choices'][0]
198
+
199
+ # Configure LoRA
200
+ lora_config = LoraConfig(
201
+ r=int(os.getenv('LORA_R', '8')),
202
+ lora_alpha=int(os.getenv('LORA_ALPHA', '32')),
203
+ target_modules=os.getenv('LORA_TARGET_MODULES', 'q,v').split(','),
204
+ lora_dropout=float(os.getenv('LORA_DROPOUT', '0.1')),
205
+ bias="none",
206
+ task_type="SEQ_2_SEQ_LM"
207
+ )
208
+
209
+ logger.info("Preparing model for training...")
210
+ model = get_peft_model(self.model, lora_config)
211
+ model.print_trainable_parameters()
212
+
213
+ # Training loop
214
+ logger.info("Starting training loop...")
215
+ optimizer = torch.optim.AdamW(model.parameters(), lr=float(os.getenv('LEARNING_RATE', '1e-4')))
216
+
217
+ # Single training step for this annotation
218
+ model.train()
219
+ optimizer.zero_grad()
220
+
221
+ inputs = self.tokenizer(text, return_tensors="pt", max_length=self.max_length, truncation=True).to(self.device)
222
+ labels = self.tokenizer(label, return_tensors="pt", max_length=self.generation_max_length, truncation=True).to(self.device)
223
+
224
+ outputs = model(**inputs, labels=labels["input_ids"])
225
+ loss = outputs.loss
226
+ loss.backward()
227
+ optimizer.step()
228
+
229
+ logger.info(f"Training step completed. Loss: {loss.item():.4f}")
230
+
231
+ # Switch back to eval mode
232
+ model.eval()
233
+
234
+ training_time = time.time() - start_time
235
+ logger.info(f"Training session completed successfully in {training_time:.2f} seconds with tag: '{text}' and label: '{label}'")
236
+
237
+ except Exception as e:
238
+ training_time = time.time() - start_time
239
+ logger.error(f"Training failed after {training_time:.2f} seconds")
240
+ logger.error(f"Error during training: {str(e)}")
241
+ raise
242