b2u committed on
Commit
535fc0a
·
1 Parent(s): 0e770ae

debugging prediction request and response

Browse files
Files changed (1) hide show
  1. model.py +36 -41
model.py CHANGED
@@ -105,64 +105,59 @@ class BertClassifier(LabelStudioMLBase):
105
  return self
106
 
107
  def predict(self, tasks, **kwargs):
108
- """Generate predictions for a list of tasks."""
 
 
 
 
 
 
 
 
109
  logger.info("=== PREDICT METHOD CALLED ===")
 
110
  logger.info(f"Number of tasks: {len(tasks)}")
111
-
112
- # Verify model state
113
- model_path = os.path.join(self.model_dir, 'model_state.pt')
114
- if os.path.exists(model_path):
115
- logger.info(f"✓ Using trained model from: {model_path}")
116
- else:
117
- logger.info("✗ No trained model found, using initial state")
118
-
119
  predictions = []
120
- for task in tasks:
121
- task_id = task['id']
122
- text = task['data']['text']
123
- logger.info(f"Processing task {task_id} - Text: {text[:50]}...")
124
-
125
  try:
126
- # Prepare the text
127
- inputs = self.tokenizer(
128
- text,
129
- truncation=True,
130
- padding=True,
131
- return_tensors='pt'
132
- ).to(self.device)
133
 
134
- # Get model predictions
135
- self._model.eval()
136
- with torch.no_grad():
137
- outputs = self._model(**inputs)
138
-
139
- # Get predicted category and confidence
140
- probabilities = torch.softmax(outputs.logits, dim=1)
141
- confidence, predicted_idx = torch.max(probabilities, dim=1)
142
- predicted_category = self.categories[predicted_idx.item()]
143
- confidence = confidence.item()
144
 
145
- logger.info(f"Predicted category: {predicted_category} with confidence: {confidence:.4f}")
 
 
146
 
147
- # Format prediction for Label Studio
148
- predictions.append({
149
  'result': [{
150
  'from_name': 'sentiment',
151
  'to_name': 'text',
152
  'type': 'choices',
153
  'value': {
154
- 'choices': [predicted_category]
155
- }
 
156
  }],
157
- 'score': confidence,
158
- 'model_version': 'bert-base-uncased-v1'
159
- })
 
160
 
161
  except Exception as e:
162
- logger.error(f"Error predicting task {task_id}: {str(e)}")
163
  continue
164
-
165
  logger.info(f"Returning {len(predictions)} predictions")
 
166
  return predictions
167
 
168
  def fit(self, event_data, data=None, **kwargs):
 
105
  return self
106
 
107
  def predict(self, tasks, **kwargs):
108
+ """
109
+ Tasks is a list of tasks with the following fields:
110
+ {
111
+ "id": 123,
112
+ "data": {
113
+ "text": "Example text"
114
+ }
115
+ }
116
+ """
117
  logger.info("=== PREDICT METHOD CALLED ===")
118
+ logger.info(f"Received tasks: {json.dumps(tasks, indent=2)}")
119
  logger.info(f"Number of tasks: {len(tasks)}")
120
+
 
 
 
 
 
 
 
121
  predictions = []
122
+
123
+ for task_index, task in enumerate(tasks, 1):
 
 
 
124
  try:
125
+ # Log the specific task being processed
126
+ logger.info(f"Processing task {task_index} - Text: {task['data'].get('text', '')[:20]}...")
 
 
 
 
 
127
 
128
+ # Log model state
129
+ model_path = os.path.join(self.model_dir, 'model_state.pt')
130
+ if os.path.exists(model_path):
131
+ logger.info("✓ Using trained model")
132
+ else:
133
+ logger.info("✗ No trained model found, using initial state")
 
 
 
 
134
 
135
+ # Get model prediction
136
+ predicted_label, confidence = self._get_prediction(task['data']['text'])
137
+ logger.info(f"Predicted category: {predicted_label} with confidence: {confidence:.4f}")
138
 
139
+ # Format the prediction for Label Studio
140
+ prediction = {
141
  'result': [{
142
  'from_name': 'sentiment',
143
  'to_name': 'text',
144
  'type': 'choices',
145
  'value': {
146
+ 'choices': [predicted_label]
147
+ },
148
+ 'score': confidence
149
  }],
150
+ 'model_version': self.model_version,
151
+ 'task': task['id']
152
+ }
153
+ predictions.append(prediction)
154
 
155
  except Exception as e:
156
+ logger.error(f"Error predicting task {task_index}: {str(e)}")
157
  continue
158
+
159
  logger.info(f"Returning {len(predictions)} predictions")
160
+ logger.info(f"Predictions: {json.dumps(predictions, indent=2)}")
161
  return predictions
162
 
163
  def fit(self, event_data, data=None, **kwargs):