b2u committed on
Commit
150b3d1
·
1 Parent(s): b8bda62

logs cleanup

Browse files
Files changed (1) hide show
  1. model.py +29 -68
model.py CHANGED
@@ -76,81 +76,42 @@ class BertClassifier(LabelStudioMLBase):
76
  logger.error("Full error details:", exc_info=True)
77
 
78
  def predict(self, tasks, **kwargs):
79
- """Make predictions for tasks"""
80
- predictions = []
 
81
 
82
- try:
83
- logger.info("=== PREDICT METHOD CALLED ===")
84
- logger.info(f"Number of tasks received: {len(tasks)}")
85
-
86
- if self._model is None or self.tokenizer is None:
87
- logger.error("Model or tokenizer not initialized")
88
- return []
 
 
 
 
 
89
 
90
- for task in tasks:
91
- logger.info(f"Processing task ID: {task.get('id')}")
92
- text = task['data'].get('text', '')
93
- logger.info(f"Text to predict: {text[:100]}...")
 
 
 
94
 
95
- try:
96
- inputs = self.tokenizer(
97
- text,
98
- truncation=True,
99
- padding=True,
100
- return_tensors='pt'
101
- ).to(self.device)
102
-
103
- self._model.eval()
104
- with torch.no_grad():
105
- outputs = self._model(**inputs)
106
- probs = torch.softmax(outputs.logits, dim=1)
107
- predicted_idx = torch.argmax(probs, dim=1).item()
108
- confidence = probs[0][predicted_idx].item()
109
-
110
- predicted_category = self.categories[predicted_idx]
111
- logger.info(f"Predicted category: {predicted_category} with confidence: {confidence:.4f}")
112
-
113
- prediction = {
114
- 'result': [{
115
- 'from_name': 'sentiment',
116
- 'to_name': 'text',
117
- 'type': 'choices',
118
- 'value': {
119
- 'choices': [predicted_category]
120
- }
121
- }],
122
- 'score': confidence,
123
- 'model_version': 'bert-base-uncased-v1'
124
- }
125
- predictions.append(prediction)
126
-
127
- except Exception as e:
128
- logger.error(f"Error processing individual task: {str(e)}")
129
- logger.error("Full error details:", exc_info=True)
130
- predictions.append({
131
- 'result': [],
132
- 'score': 0,
133
- 'model_version': 'bert-base-uncased-v1'
134
- })
135
-
136
- except Exception as e:
137
- logger.error(f"Error in predict: {str(e)}")
138
- logger.error("Full error details:", exc_info=True)
139
- return []
140
-
141
  logger.info(f"Returning {len(predictions)} predictions")
142
  return predictions
143
 
144
  def fit(self, event_data, data=None, **kwargs):
145
  """Train the model on a single annotation."""
146
  start_time = datetime.now()
147
- logger.info(f"=== FIT METHOD CALLED ===")
148
- logger.info(f"Event data: {event_data}")
149
- logger.info(f"Data received: {json.dumps(data, indent=2)}")
150
 
151
  try:
152
  if event_data == 'ANNOTATION_CREATED':
153
- logger.info("Processing ANNOTATION_CREATED event")
154
  annotation = data.get('annotation', {})
155
  task = data.get('task', {})
156
 
@@ -175,12 +136,12 @@ class BertClassifier(LabelStudioMLBase):
175
  tokenizer=self.tokenizer
176
  )
177
  train_loader = DataLoader(dataset, batch_size=1)
178
- logger.info("Dataset created successfully")
179
 
180
  # Setup training
181
  optimizer = AdamW(self._model.parameters(), lr=2e-5)
182
  self._model.train()
183
- logger.info("Starting training loop...")
184
 
185
  # Single example training
186
  for batch in train_loader:
@@ -199,16 +160,16 @@ class BertClassifier(LabelStudioMLBase):
199
  )
200
 
201
  loss = outputs.loss
202
- logger.info(f"Training loss: {loss.item()}")
203
 
204
  # Backward pass
205
  loss.backward()
206
  optimizer.step()
207
 
208
- # Save the model after training
209
  model_path = os.path.join(self.model_dir, 'model_state.pt')
210
  torch.save(self._model.state_dict(), model_path)
211
- logger.info(f"Model saved to {model_path}")
212
 
213
  return {
214
  'status': 'ok',
 
76
  logger.error("Full error details:", exc_info=True)
77
 
78
  def predict(self, tasks, **kwargs):
79
+ """Generate predictions for a list of tasks."""
80
+ logger.info("=== PREDICT METHOD CALLED ===")
81
+ logger.info(f"Number of tasks: {len(tasks)}")
82
 
83
+ # Verify model state
84
+ model_path = os.path.join(self.model_dir, 'model_state.pt')
85
+ if os.path.exists(model_path):
86
+ logger.info(f"✓ Using trained model from: {model_path}")
87
+ else:
88
+ logger.info(" No trained model found, using initial state")
89
+
90
+ predictions = []
91
+ for task in tasks:
92
+ task_id = task['id']
93
+ text = task['data']['text']
94
+ logger.info(f"Processing task {task_id} - Text: {text[:50]}...")
95
 
96
+ try:
97
+ # ... prediction code ...
98
+ logger.info(f"Predicted category: {predicted_category} with confidence: {confidence:.4f}")
99
+
100
+ except Exception as e:
101
+ logger.error(f"Error predicting task {task_id}: {str(e)}")
102
+ continue
103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  logger.info(f"Returning {len(predictions)} predictions")
105
  return predictions
106
 
107
  def fit(self, event_data, data=None, **kwargs):
108
  """Train the model on a single annotation."""
109
  start_time = datetime.now()
110
+ logger.info("=== FIT METHOD CALLED ===")
111
+ logger.info(f"Event type: {event_data}")
 
112
 
113
  try:
114
  if event_data == 'ANNOTATION_CREATED':
 
115
  annotation = data.get('annotation', {})
116
  task = data.get('task', {})
117
 
 
136
  tokenizer=self.tokenizer
137
  )
138
  train_loader = DataLoader(dataset, batch_size=1)
139
+ logger.info("Dataset created")
140
 
141
  # Setup training
142
  optimizer = AdamW(self._model.parameters(), lr=2e-5)
143
  self._model.train()
144
+ logger.info("Starting training...")
145
 
146
  # Single example training
147
  for batch in train_loader:
 
160
  )
161
 
162
  loss = outputs.loss
163
+ logger.info(f"Training loss: {loss.item():.4f}")
164
 
165
  # Backward pass
166
  loss.backward()
167
  optimizer.step()
168
 
169
+ # Save the model
170
  model_path = os.path.join(self.model_dir, 'model_state.pt')
171
  torch.save(self._model.state_dict(), model_path)
172
+ logger.info(f"Model saved to {model_path}")
173
 
174
  return {
175
  'status': 'ok',