b2u committed on
Commit
fa5ac26
·
1 Parent(s): a3330c8

debugging

Browse files
Files changed (1) hide show
  1. model.py +54 -33
model.py CHANGED
@@ -76,6 +76,9 @@ class BertClassifier(LabelStudioMLBase):
76
  logger.info(f"Initializing BertClassifier with project_id: {project_id}")
77
  logger.info(f"Label config: {label_config}")
78
 
 
 
 
79
  self.label_encoder = LabelEncoder()
80
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
81
  self.instruction_template = os.getenv('MODEL_INSTRUCTIONS', '{text}')
@@ -178,19 +181,24 @@ class BertClassifier(LabelStudioMLBase):
178
  return predictions
179
 
180
  def get_tasks(self):
181
- """Get tasks from local storage"""
182
  try:
183
- storage_path = os.path.join(self.model_dir, 'tasks.json')
184
- if os.path.exists(storage_path):
185
- with open(storage_path, 'r') as f:
186
- tasks = json.load(f)
187
- logger.info(f"Loaded {len(tasks)} tasks from storage")
188
- return tasks
189
- else:
190
- logger.warning("No tasks found in storage")
191
- return []
 
 
 
 
192
  except Exception as e:
193
- logger.error(f"Error retrieving tasks: {str(e)}")
 
194
  return []
195
 
196
  def fit(self, completions, workdir=None, **kwargs):
@@ -207,25 +215,13 @@ class BertClassifier(LabelStudioMLBase):
207
  # Extract training data
208
  texts, labels = [], []
209
 
210
- # If completions is a string (like "START_TRAINING"), try to get tasks
211
- if isinstance(completions, str):
212
- tasks = self.get_tasks()
213
- logger.info(f"Retrieved {len(tasks)} tasks from storage")
214
- # Debug first task if available
215
- if tasks:
216
- logger.info(f"First task content: {json.dumps(tasks[0], indent=2)}")
217
- else:
218
- # If completions is a list, use it directly
219
- tasks = completions if isinstance(completions, list) else [completions]
220
- logger.info(f"Using {len(tasks)} tasks from completions")
221
- if tasks:
222
- logger.info(f"First completion content: {json.dumps(tasks[0], indent=2)}")
223
-
224
- logger.info(f"Processing {len(tasks)} tasks")
225
 
226
  # Get interface info
227
- from_name, to_name, value = self.label_interface.get_first_tag_occurence('Choices', 'Text')
228
- logger.info(f"Interface info: from_name={from_name}, to_name={to_name}, value={value}")
229
 
230
  for task in tasks:
231
  try:
@@ -237,17 +233,20 @@ class BertClassifier(LabelStudioMLBase):
237
 
238
  # Get annotations
239
  annotations = task.get('annotations', [])
240
- logger.info(f"Found {len(annotations)} annotations for task {task.get('id')}")
 
 
241
 
242
  if not annotations:
243
  logger.warning(f"No annotations found for task {task.get('id')}")
244
  continue
245
 
246
  for annotation in annotations:
247
- try:
248
- # Debug annotation content
249
- logger.info(f"Annotation content: {json.dumps(annotation, indent=2)}")
250
 
 
251
  # Get choices from result
252
  results = annotation.get('result', [])
253
  if not results:
@@ -263,7 +262,7 @@ class BertClassifier(LabelStudioMLBase):
263
  texts.append(text)
264
  labels.append(label)
265
  break
266
-
267
  except Exception as e:
268
  logger.error(f"Error processing annotation: {str(e)}")
269
  continue
@@ -398,3 +397,25 @@ class BertClassifier(LabelStudioMLBase):
398
  except Exception as e:
399
  logger.error(f"Error saving task: {str(e)}")
400
  logger.error("Full error details:", exc_info=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  logger.info(f"Initializing BertClassifier with project_id: {project_id}")
77
  logger.info(f"Label config: {label_config}")
78
 
79
+ # Initialize Label Studio client
80
+ self.label_studio_client = self.connect_to_label_studio()
81
+
82
  self.label_encoder = LabelEncoder()
83
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
84
  self.instruction_template = os.getenv('MODEL_INSTRUCTIONS', '{text}')
 
181
  return predictions
182
 
183
  def get_tasks(self):
184
+ """Get tasks from Label Studio"""
185
  try:
186
+ # Get tasks from Label Studio API
187
+ params = {'project': self.project_id} if self.project_id else {}
188
+ response = self.label_studio_client.make_request('GET', '/api/tasks', params=params)
189
+ tasks = response.json()
190
+
191
+ logger.info(f"Retrieved {len(tasks)} tasks from Label Studio API")
192
+
193
+ # Debug first task if available
194
+ if tasks:
195
+ logger.info(f"First task content: {json.dumps(tasks[0], indent=2)}")
196
+
197
+ return tasks
198
+
199
  except Exception as e:
200
+ logger.error(f"Error retrieving tasks from Label Studio: {str(e)}")
201
+ logger.error("Full error details:", exc_info=True)
202
  return []
203
 
204
  def fit(self, completions, workdir=None, **kwargs):
 
215
  # Extract training data
216
  texts, labels = [], []
217
 
218
+ # Get tasks from Label Studio
219
+ tasks = self.get_tasks()
220
+ logger.info(f"Retrieved {len(tasks)} tasks from Label Studio")
 
 
 
 
 
 
 
 
 
 
 
 
221
 
222
  # Get interface info
223
+ from_name = 'sentiment' # This matches your label config
224
+ to_name = 'text' # This matches your label config
225
 
226
  for task in tasks:
227
  try:
 
233
 
234
  # Get annotations
235
  annotations = task.get('annotations', [])
236
+ if annotations:
237
+ logger.info(f"Found {len(annotations)} annotations for task {task.get('id')}")
238
+ logger.info(f"Annotation content: {json.dumps(annotations[0], indent=2)}")
239
 
240
  if not annotations:
241
  logger.warning(f"No annotations found for task {task.get('id')}")
242
  continue
243
 
244
  for annotation in annotations:
245
+ # Only use completed annotations
246
+ if annotation.get('was_cancelled') or not annotation.get('completed_by'):
247
+ continue
248
 
249
+ try:
250
  # Get choices from result
251
  results = annotation.get('result', [])
252
  if not results:
 
262
  texts.append(text)
263
  labels.append(label)
264
  break
265
+
266
  except Exception as e:
267
  logger.error(f"Error processing annotation: {str(e)}")
268
  continue
 
397
  except Exception as e:
398
  logger.error(f"Error saving task: {str(e)}")
399
  logger.error("Full error details:", exc_info=True)
400
+
401
+ def connect_to_label_studio(self):
402
+ """Connect to Label Studio API"""
403
+ try:
404
+ from label_studio_sdk import Client
405
+
406
+ # Get Label Studio connection details from environment
407
+ ls_url = os.getenv('LABEL_STUDIO_URL', 'http://localhost:8080')
408
+ ls_token = os.getenv('LABEL_STUDIO_API_TOKEN')
409
+
410
+ if not ls_token:
411
+ raise ValueError("LABEL_STUDIO_API_TOKEN environment variable is not set")
412
+
413
+ # Initialize client
414
+ client = Client(url=ls_url, api_key=ls_token)
415
+ logger.info(f"Connected to Label Studio at {ls_url}")
416
+ return client
417
+
418
+ except Exception as e:
419
+ logger.error(f"Error connecting to Label Studio: {str(e)}")
420
+ logger.error("Full error details:", exc_info=True)
421
+ raise