b2u commited on
Commit
8f82581
·
1 Parent(s): ad8c734

debugging why predictions are not comming

Browse files
Files changed (1) hide show
  1. model.py +20 -18
model.py CHANGED
@@ -20,41 +20,43 @@ class BertClassifier(LabelStudioMLBase):
20
  self._model = None
21
  self.tokenizer = None
22
 
23
- def predict(self, tasks, context=None, **kwargs):
24
- """
25
- tasks: Label Studio tasks in JSON format
26
- context: Label Studio context in JSON format (for interactive labeling)
27
- returns: predictions array in JSON format
28
- """
29
  predictions = []
30
 
31
  try:
32
- # Process each task
 
 
33
  for task in tasks:
34
- # Get text from the task
 
 
 
35
  text = task['data'].get('text', '')
36
- if not text:
37
- continue
38
-
39
- # Make prediction
40
  prediction = {
41
  'result': [{
42
- 'from_name': 'sentiment', # must match your labeling config
43
- 'to_name': 'text', # must match your labeling config
44
  'type': 'choices',
45
  'value': {
46
- 'choices': ['some_label'] # your predicted label
47
  }
48
  }],
49
- 'score': 0.0, # confidence score between 0 and 1
50
- 'model_version': self.model_dir
51
  }
 
52
  predictions.append(prediction)
53
 
54
  except Exception as e:
55
- print(f"Error in predict: {str(e)}")
 
56
  return []
57
 
 
58
  return predictions
59
 
60
  def fit(self, completions, workdir=None, **kwargs):
 
20
  self._model = None
21
  self.tokenizer = None
22
 
23
+ def predict(self, tasks, **kwargs):
24
+ """Make predictions for tasks"""
 
 
 
 
25
  predictions = []
26
 
27
  try:
28
+ logger.info("=== PREDICT METHOD CALLED ===")
29
+ logger.info(f"Number of tasks received: {len(tasks)}")
30
+
31
  for task in tasks:
32
+ logger.info(f"Processing task ID: {task.get('id')}")
33
+ logger.info(f"Task content: {json.dumps(task, indent=2)}")
34
+
35
+ # Get the text to classify
36
  text = task['data'].get('text', '')
37
+ logger.info(f"Text to predict: {text[:100]}...")
38
+
 
 
39
  prediction = {
40
  'result': [{
41
+ 'from_name': 'sentiment',
42
+ 'to_name': 'text',
43
  'type': 'choices',
44
  'value': {
45
+ 'choices': ['brand']
46
  }
47
  }],
48
+ 'score': 0.5,
49
+ 'model_version': 'v1'
50
  }
51
+ logger.info(f"Generated prediction: {json.dumps(prediction, indent=2)}")
52
  predictions.append(prediction)
53
 
54
  except Exception as e:
55
+ logger.error(f"Error in predict: {str(e)}")
56
+ logger.error("Full error details:", exc_info=True)
57
  return []
58
 
59
+ logger.info(f"Returning {len(predictions)} predictions")
60
  return predictions
61
 
62
  def fit(self, completions, workdir=None, **kwargs):