b2u committed on
Commit
62c0df2
·
1 Parent(s): 41785d5

Instead of trying to parse the completions string, we use the Label Studio interface to get tasks directly

Browse files
Files changed (1) hide show
  1. model.py +24 -21
model.py CHANGED
@@ -210,12 +210,19 @@ class BertClassifier(LabelStudioMLBase):
210
  # Extract training data
211
  texts, labels = [], []
212
 
213
- # Process completions directly from Label Studio
214
  try:
215
- for task in completions:
 
 
 
 
 
 
 
216
  try:
217
  # Get text from task
218
- text = task.get('data', {}).get('text', '')
219
  if not text:
220
  continue
221
 
@@ -231,26 +238,26 @@ class BertClassifier(LabelStudioMLBase):
231
  if not result:
232
  continue
233
 
234
- choices = result[0].get('value', {}).get('choices', [])
235
- if not choices:
236
- continue
237
-
238
- label = choices[0]
239
-
240
- logger.info(f"Successfully extracted: Text='{text}', Label='{label}'")
241
- texts.append(text)
242
- labels.append(label)
243
-
244
  except Exception as e:
245
  logger.error(f"Error processing annotation: {str(e)}")
246
  continue
247
-
248
  except Exception as e:
249
  logger.error(f"Error processing task: {str(e)}")
250
  continue
251
 
252
  except Exception as e:
253
- logger.error(f"Error processing completions: {str(e)}")
254
 
255
  logger.info(f"Prepared {len(texts)} examples for training")
256
 
@@ -280,13 +287,9 @@ class BertClassifier(LabelStudioMLBase):
280
 
281
  tokenized_dataset = train_dataset.map(tokenize_function, batched=True)
282
 
283
- # Define output directory
284
- output_dir = os.path.join(self.model_dir, "results")
285
- os.makedirs(output_dir, exist_ok=True)
286
-
287
  # Define training arguments
288
  training_args = TrainingArguments(
289
- output_dir=output_dir,
290
  num_train_epochs=3,
291
  per_device_train_batch_size=8,
292
  per_device_eval_batch_size=8,
@@ -331,7 +334,7 @@ class BertClassifier(LabelStudioMLBase):
331
 
332
  except Exception as e:
333
  logger.error(f"Training failed: {str(e)}")
334
- logger.error('Full error details:', exc_info=True)
335
  return {
336
  'status': 'error',
337
  'error': str(e),
 
210
  # Extract training data
211
  texts, labels = [], []
212
 
213
+ # Get annotations from Label Studio
214
  try:
215
+ # Get interface info
216
+ from_name, to_name, value = self.label_interface.get_first_tag_occurence('Choices', 'Text')
217
+
218
+ # Get tasks from Label Studio
219
+ tasks = self.label_interface.get_tasks()
220
+ logger.info(f"Found {len(tasks)} tasks")
221
+
222
+ for task in tasks:
223
  try:
224
  # Get text from task
225
+ text = task.get('data', {}).get(value)
226
  if not text:
227
  continue
228
 
 
238
  if not result:
239
  continue
240
 
241
+ for r in result:
242
+ if r.get('from_name') == from_name and r.get('to_name') == to_name:
243
+ choices = r.get('value', {}).get('choices', [])
244
+ if choices:
245
+ label = choices[0]
246
+ logger.info(f"Successfully extracted: Text='{text}', Label='{label}'")
247
+ texts.append(text)
248
+ labels.append(label)
249
+ break
250
+
251
  except Exception as e:
252
  logger.error(f"Error processing annotation: {str(e)}")
253
  continue
254
+
255
  except Exception as e:
256
  logger.error(f"Error processing task: {str(e)}")
257
  continue
258
 
259
  except Exception as e:
260
+ logger.error(f"Error getting tasks: {str(e)}")
261
 
262
  logger.info(f"Prepared {len(texts)} examples for training")
263
 
 
287
 
288
  tokenized_dataset = train_dataset.map(tokenize_function, batched=True)
289
 
 
 
 
 
290
  # Define training arguments
291
  training_args = TrainingArguments(
292
+ output_dir=os.path.join(self.model_dir, "results"),
293
  num_train_epochs=3,
294
  per_device_train_batch_size=8,
295
  per_device_eval_batch_size=8,
 
334
 
335
  except Exception as e:
336
  logger.error(f"Training failed: {str(e)}")
337
+ logger.error("Full error details:", exc_info=True)
338
  return {
339
  'status': 'error',
340
  'error': str(e),