b2u commited on
Commit
c4e7614
·
1 Parent(s): 535fc0a

debugging predictions

Browse files
Files changed (1) hide show
  1. model.py +24 -15
model.py CHANGED
@@ -105,15 +105,6 @@ class BertClassifier(LabelStudioMLBase):
105
  return self
106
 
107
  def predict(self, tasks, **kwargs):
108
- """
109
- Tasks is a list of tasks with the following fields:
110
- {
111
- "id": 123,
112
- "data": {
113
- "text": "Example text"
114
- }
115
- }
116
- """
117
  logger.info("=== PREDICT METHOD CALLED ===")
118
  logger.info(f"Received tasks: {json.dumps(tasks, indent=2)}")
119
  logger.info(f"Number of tasks: {len(tasks)}")
@@ -122,19 +113,37 @@ class BertClassifier(LabelStudioMLBase):
122
 
123
  for task_index, task in enumerate(tasks, 1):
124
  try:
125
- # Log the specific task being processed
126
  logger.info(f"Processing task {task_index} - Text: {task['data'].get('text', '')[:20]}...")
127
 
128
- # Log model state
129
  model_path = os.path.join(self.model_dir, 'model_state.pt')
130
  if os.path.exists(model_path):
131
  logger.info("✓ Using trained model")
132
  else:
133
  logger.info("✗ No trained model found, using initial state")
134
 
135
- # Get model prediction
136
- predicted_label, confidence = self._get_prediction(task['data']['text'])
137
- logger.info(f"Predicted category: {predicted_label} with confidence: {confidence:.4f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
  # Format the prediction for Label Studio
140
  prediction = {
@@ -145,7 +154,7 @@ class BertClassifier(LabelStudioMLBase):
145
  'value': {
146
  'choices': [predicted_label]
147
  },
148
- 'score': confidence
149
  }],
150
  'model_version': self.model_version,
151
  'task': task['id']
 
105
  return self
106
 
107
  def predict(self, tasks, **kwargs):
 
 
 
 
 
 
 
 
 
108
  logger.info("=== PREDICT METHOD CALLED ===")
109
  logger.info(f"Received tasks: {json.dumps(tasks, indent=2)}")
110
  logger.info(f"Number of tasks: {len(tasks)}")
 
113
 
114
  for task_index, task in enumerate(tasks, 1):
115
  try:
 
116
  logger.info(f"Processing task {task_index} - Text: {task['data'].get('text', '')[:20]}...")
117
 
 
118
  model_path = os.path.join(self.model_dir, 'model_state.pt')
119
  if os.path.exists(model_path):
120
  logger.info("✓ Using trained model")
121
  else:
122
  logger.info("✗ No trained model found, using initial state")
123
 
124
+ # Prepare the text for the model
125
+ inputs = self.tokenizer(
126
+ task['data']['text'],
127
+ truncation=True,
128
+ padding=True,
129
+ return_tensors="pt"
130
+ ).to(self.device)
131
+
132
+ # Set model to evaluation mode
133
+ self._model.eval()
134
+
135
+ # Get prediction
136
+ with torch.no_grad():
137
+ outputs = self._model(**inputs)
138
+ logits = outputs.logits
139
+ probabilities = torch.softmax(logits, dim=1)
140
+ confidence, predicted_idx = torch.max(probabilities, dim=1)
141
+
142
+ # Get predicted label
143
+ predicted_label = self.categories[predicted_idx.item()]
144
+ confidence_score = confidence.item()
145
+
146
+ logger.info(f"Predicted category: {predicted_label} with confidence: {confidence_score:.4f}")
147
 
148
  # Format the prediction for Label Studio
149
  prediction = {
 
154
  'value': {
155
  'choices': [predicted_label]
156
  },
157
+ 'score': confidence_score
158
  }],
159
  'model_version': self.model_version,
160
  'task': task['id']