b2u committed on
Commit
c6dedc8
·
1 Parent(s): da5d943

keep debugging

Browse files
Files changed (1) hide show
  1. model.py +98 -52
model.py CHANGED
@@ -3,6 +3,7 @@ import torch
3
  import logging
4
  import pathlib
5
  import pickle
 
6
  from typing import List, Dict, Optional
7
  from label_studio_ml.model import LabelStudioMLBase
8
  from transformers import (
@@ -70,10 +71,11 @@ class BertClassifier(LabelStudioMLBase):
70
  _model = None
71
 
72
  def __init__(self, project_id=None, label_config=None, **kwargs):
73
- # Initialize parent class properly
74
  super(BertClassifier, self).__init__(project_id=project_id, label_config=label_config)
75
 
76
- # Your existing initialization code
 
 
77
  self.label_encoder = LabelEncoder()
78
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
79
  self.instruction_template = os.getenv('MODEL_INSTRUCTIONS', '{text}')
@@ -82,9 +84,6 @@ class BertClassifier(LabelStudioMLBase):
82
  self.model_dir = os.path.join(os.path.dirname(__file__), 'model')
83
  os.makedirs(self.model_dir, exist_ok=True)
84
 
85
- # Skip Label Studio client initialization
86
- self.label_studio_client = None
87
-
88
  # Define your categories
89
  self.categories = [
90
  'affiliate_classification', 'brand', 'business_and_career',
@@ -196,6 +195,30 @@ class BertClassifier(LabelStudioMLBase):
196
  logger.error("Full error details:", exc_info=True)
197
  raise
198
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  def fit(self, completions, workdir=None, **kwargs):
200
  """Train model on labeled data"""
201
  logger.info('Starting model training...')
@@ -210,57 +233,57 @@ class BertClassifier(LabelStudioMLBase):
210
  # Extract training data
211
  texts, labels = [], []
212
 
213
- try:
214
- # Get interface info
215
- from_name, to_name, value = self.label_interface.get_first_tag_occurence('Choices', 'Text')
216
-
217
- # Get tasks from Label Studio
218
  tasks = self.get_tasks()
219
- logger.info(f"Found {len(tasks)} tasks")
220
-
221
- for task in tasks:
222
- try:
223
- # Get text from task
224
- text = task['data'].get(value)
225
- if not text:
226
- logger.warning(f"No text found in task {task.get('id')}")
227
- continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
 
229
- # Get annotations
230
- annotations = task.get('annotations', [])
231
- if not annotations:
232
- logger.warning(f"No annotations found for task {task.get('id')}")
233
  continue
234
 
235
- for annotation in annotations:
236
- try:
237
- # Get choices from result
238
- results = annotation.get('result', [])
239
- if not results:
240
- logger.warning(f"No results found in annotation for task {task.get('id')}")
241
- continue
242
-
243
- for result in results:
244
- if result.get('from_name') == from_name and result.get('to_name') == to_name:
245
- choices = result.get('value', {}).get('choices', [])
246
- if choices:
247
- label = choices[0]
248
- logger.info(f"Successfully extracted: Text='{text[:50]}...', Label='{label}'")
249
- texts.append(text)
250
- labels.append(label)
251
- break
252
-
253
- except Exception as e:
254
- logger.error(f"Error processing annotation: {str(e)}")
255
- continue
256
-
257
- except Exception as e:
258
- logger.error(f"Error processing task: {str(e)}")
259
- continue
260
-
261
- except Exception as e:
262
- logger.error(f"Error getting tasks: {str(e)}")
263
- logger.error("Full error details:", exc_info=True)
264
 
265
  logger.info(f"Prepared {len(texts)} examples for training")
266
 
@@ -343,3 +366,26 @@ class BertClassifier(LabelStudioMLBase):
343
  'error': str(e),
344
  'train_size': len(texts) if 'texts' in locals() else 0
345
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import logging
4
  import pathlib
5
  import pickle
6
+ import json
7
  from typing import List, Dict, Optional
8
  from label_studio_ml.model import LabelStudioMLBase
9
  from transformers import (
 
71
  _model = None
72
 
73
  def __init__(self, project_id=None, label_config=None, **kwargs):
 
74
  super(BertClassifier, self).__init__(project_id=project_id, label_config=label_config)
75
 
76
+ logger.info(f"Initializing BertClassifier with project_id: {project_id}")
77
+ logger.info(f"Label config: {label_config}")
78
+
79
  self.label_encoder = LabelEncoder()
80
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
81
  self.instruction_template = os.getenv('MODEL_INSTRUCTIONS', '{text}')
 
84
  self.model_dir = os.path.join(os.path.dirname(__file__), 'model')
85
  os.makedirs(self.model_dir, exist_ok=True)
86
 
 
 
 
87
  # Define your categories
88
  self.categories = [
89
  'affiliate_classification', 'brand', 'business_and_career',
 
195
  logger.error("Full error details:", exc_info=True)
196
  raise
197
 
198
def get_tasks(self):
    """Load previously stored tasks from the local ``tasks.json`` file.

    The file lives at ``<self.model_dir>/tasks.json`` and is expected to
    contain a JSON array of task objects.

    Returns:
        The list of stored tasks, or an empty list when the file does not
        exist, cannot be read, is not valid JSON, or does not hold a JSON
        array. Storage problems are logged rather than raised.
    """
    # NOTE: the label-interface lookup that used to happen here was removed:
    # its results were never used, and a failure in it made this method
    # return [] even when stored tasks were available.
    storage_path = os.path.join(self.model_dir, 'tasks.json')

    try:
        with open(storage_path, 'r', encoding='utf-8') as f:
            tasks = json.load(f)
    except FileNotFoundError:
        # Nothing has been stored yet; this is not an error.
        return []
    except (OSError, ValueError) as e:
        # OSError: unreadable file; ValueError: malformed JSON or bad encoding.
        logger.error(f"Error loading tasks from storage: {str(e)}")
        return []

    # Callers iterate the result as a list of task dicts, so reject any
    # other top-level JSON value instead of handing it back.
    if not isinstance(tasks, list):
        logger.error(
            "Error loading tasks from storage: expected a JSON list, "
            f"got {type(tasks).__name__}"
        )
        return []

    logger.info(f"Loaded {len(tasks)} tasks from storage")
    return tasks
221
+
222
  def fit(self, completions, workdir=None, **kwargs):
223
  """Train model on labeled data"""
224
  logger.info('Starting model training...')
 
233
  # Extract training data
234
  texts, labels = [], []
235
 
236
+ # If completions is a string (like "START_TRAINING"), try to get tasks
237
+ if isinstance(completions, str):
 
 
 
238
  tasks = self.get_tasks()
239
+ else:
240
+ # If completions is a list, use it directly
241
+ tasks = completions if isinstance(completions, list) else [completions]
242
+
243
+ logger.info(f"Processing {len(tasks)} tasks")
244
+
245
+ # Get interface info
246
+ from_name, to_name, value = self.label_interface.get_first_tag_occurence('Choices', 'Text')
247
+
248
+ for task in tasks:
249
+ try:
250
+ # Get text from task
251
+ text = task['data'].get(value) if isinstance(task, dict) else None
252
+ if not text:
253
+ logger.warning(f"No text found in task")
254
+ continue
255
+
256
+ # Get annotations
257
+ annotations = task.get('annotations', []) if isinstance(task, dict) else []
258
+ if not annotations:
259
+ logger.warning(f"No annotations found for task")
260
+ continue
261
+
262
+ for annotation in annotations:
263
+ try:
264
+ # Get choices from result
265
+ results = annotation.get('result', [])
266
+ if not results:
267
+ logger.warning(f"No results found in annotation")
268
+ continue
269
+
270
+ for result in results:
271
+ if result.get('from_name') == from_name and result.get('to_name') == to_name:
272
+ choices = result.get('value', {}).get('choices', [])
273
+ if choices:
274
+ label = choices[0]
275
+ logger.info(f"Successfully extracted: Text='{text[:50]}...', Label='{label}'")
276
+ texts.append(text)
277
+ labels.append(label)
278
+ break
279
 
280
+ except Exception as e:
281
+ logger.error(f"Error processing annotation: {str(e)}")
 
 
282
  continue
283
 
284
+ except Exception as e:
285
+ logger.error(f"Error processing task: {str(e)}")
286
+ continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
 
288
  logger.info(f"Prepared {len(texts)} examples for training")
289
 
 
366
  'error': str(e),
367
  'train_size': len(texts) if 'texts' in locals() else 0
368
  }
369
+
370
def save_task(self, task):
    """Append one task to the local ``tasks.json`` store (best effort).

    The store is ``<self.model_dir>/tasks.json`` and holds a JSON array.
    The updated array is serialized in memory first and then written to a
    temporary file that atomically replaces the store, so a task that
    cannot be serialized or a failed write never truncates the tasks that
    were already saved.

    Args:
        task: A JSON-serializable task object to append.

    Returns:
        None. Any failure is logged and swallowed; the existing store is
        left untouched in that case.
    """
    try:
        storage_path = os.path.join(self.model_dir, 'tasks.json')
        temp_path = storage_path + '.tmp'
        tasks = []

        # Load existing tasks
        if os.path.exists(storage_path):
            with open(storage_path, 'r', encoding='utf-8') as f:
                tasks = json.load(f)

        # Refuse to overwrite a store we do not understand.
        if not isinstance(tasks, list):
            raise ValueError(
                f"expected a JSON list in {storage_path}, "
                f"got {type(tasks).__name__}"
            )

        # Add new task
        tasks.append(task)

        # Serialize before touching the disk so a non-serializable task
        # fails here, while the old file is still intact.
        payload = json.dumps(tasks)

        # Write to a sibling temp file, then swap it in atomically.
        with open(temp_path, 'w', encoding='utf-8') as f:
            f.write(payload)
        os.replace(temp_path, storage_path)

        logger.info(f"Saved task to storage. Total tasks: {len(tasks)}")

    except Exception as e:
        # Deliberate best-effort boundary: saving a task must not crash
        # the caller, so log and continue.
        logger.error(f"Error saving task: {str(e)}")