Spaces:
Sleeping
Sleeping
Instead of trying to parse the completions string, we use the Label Studio interface to get tasks directly
Browse files
model.py
CHANGED
|
@@ -210,12 +210,19 @@ class BertClassifier(LabelStudioMLBase):
|
|
| 210 |
# Extract training data
|
| 211 |
texts, labels = [], []
|
| 212 |
|
| 213 |
-
#
|
| 214 |
try:
|
| 215 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
try:
|
| 217 |
# Get text from task
|
| 218 |
-
text = task.get('data', {}).get(
|
| 219 |
if not text:
|
| 220 |
continue
|
| 221 |
|
|
@@ -231,26 +238,26 @@ class BertClassifier(LabelStudioMLBase):
|
|
| 231 |
if not result:
|
| 232 |
continue
|
| 233 |
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
except Exception as e:
|
| 245 |
logger.error(f"Error processing annotation: {str(e)}")
|
| 246 |
continue
|
| 247 |
-
|
| 248 |
except Exception as e:
|
| 249 |
logger.error(f"Error processing task: {str(e)}")
|
| 250 |
continue
|
| 251 |
|
| 252 |
except Exception as e:
|
| 253 |
-
logger.error(f"Error
|
| 254 |
|
| 255 |
logger.info(f"Prepared {len(texts)} examples for training")
|
| 256 |
|
|
@@ -280,13 +287,9 @@ class BertClassifier(LabelStudioMLBase):
|
|
| 280 |
|
| 281 |
tokenized_dataset = train_dataset.map(tokenize_function, batched=True)
|
| 282 |
|
| 283 |
-
# Define output directory
|
| 284 |
-
output_dir = os.path.join(self.model_dir, "results")
|
| 285 |
-
os.makedirs(output_dir, exist_ok=True)
|
| 286 |
-
|
| 287 |
# Define training arguments
|
| 288 |
training_args = TrainingArguments(
|
| 289 |
-
output_dir=
|
| 290 |
num_train_epochs=3,
|
| 291 |
per_device_train_batch_size=8,
|
| 292 |
per_device_eval_batch_size=8,
|
|
@@ -331,7 +334,7 @@ class BertClassifier(LabelStudioMLBase):
|
|
| 331 |
|
| 332 |
except Exception as e:
|
| 333 |
logger.error(f"Training failed: {str(e)}")
|
| 334 |
-
logger.error(
|
| 335 |
return {
|
| 336 |
'status': 'error',
|
| 337 |
'error': str(e),
|
|
|
|
| 210 |
# Extract training data
|
| 211 |
texts, labels = [], []
|
| 212 |
|
| 213 |
+
# Get annotations from Label Studio
|
| 214 |
try:
|
| 215 |
+
# Get interface info
|
| 216 |
+
from_name, to_name, value = self.label_interface.get_first_tag_occurence('Choices', 'Text')
|
| 217 |
+
|
| 218 |
+
# Get tasks from Label Studio
|
| 219 |
+
tasks = self.label_interface.get_tasks()
|
| 220 |
+
logger.info(f"Found {len(tasks)} tasks")
|
| 221 |
+
|
| 222 |
+
for task in tasks:
|
| 223 |
try:
|
| 224 |
# Get text from task
|
| 225 |
+
text = task.get('data', {}).get(value)
|
| 226 |
if not text:
|
| 227 |
continue
|
| 228 |
|
|
|
|
| 238 |
if not result:
|
| 239 |
continue
|
| 240 |
|
| 241 |
+
for r in result:
|
| 242 |
+
if r.get('from_name') == from_name and r.get('to_name') == to_name:
|
| 243 |
+
choices = r.get('value', {}).get('choices', [])
|
| 244 |
+
if choices:
|
| 245 |
+
label = choices[0]
|
| 246 |
+
logger.info(f"Successfully extracted: Text='{text}', Label='{label}'")
|
| 247 |
+
texts.append(text)
|
| 248 |
+
labels.append(label)
|
| 249 |
+
break
|
| 250 |
+
|
| 251 |
except Exception as e:
|
| 252 |
logger.error(f"Error processing annotation: {str(e)}")
|
| 253 |
continue
|
| 254 |
+
|
| 255 |
except Exception as e:
|
| 256 |
logger.error(f"Error processing task: {str(e)}")
|
| 257 |
continue
|
| 258 |
|
| 259 |
except Exception as e:
|
| 260 |
+
logger.error(f"Error getting tasks: {str(e)}")
|
| 261 |
|
| 262 |
logger.info(f"Prepared {len(texts)} examples for training")
|
| 263 |
|
|
|
|
| 287 |
|
| 288 |
tokenized_dataset = train_dataset.map(tokenize_function, batched=True)
|
| 289 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
# Define training arguments
|
| 291 |
training_args = TrainingArguments(
|
| 292 |
+
output_dir=os.path.join(self.model_dir, "results"),
|
| 293 |
num_train_epochs=3,
|
| 294 |
per_device_train_batch_size=8,
|
| 295 |
per_device_eval_batch_size=8,
|
|
|
|
| 334 |
|
| 335 |
except Exception as e:
|
| 336 |
logger.error(f"Training failed: {str(e)}")
|
| 337 |
+
logger.error("Full error details:", exc_info=True)
|
| 338 |
return {
|
| 339 |
'status': 'error',
|
| 340 |
'error': str(e),
|