b2u commited on
Commit
b9a9837
·
1 Parent(s): 49d18ad

simplifying

Browse files
Files changed (1) hide show
  1. model.py +30 -151
model.py CHANGED
@@ -141,156 +141,35 @@ class BertClassifier(LabelStudioMLBase):
141
  logger.info(f"Returning {len(predictions)} predictions")
142
  return predictions
143
 
144
- def fit(self, completions, workdir=None, **kwargs):
145
- try:
146
- logger.info('=== STARTING MODEL TRAINING ===')
147
- logger.info(f'Received signal: {completions}')
148
-
149
- # If we receive a training signal, fetch the actual completions from Label Studio
150
- if isinstance(completions, str) and completions in ['START_TRAINING', 'ANNOTATION_CREATED']:
151
- try:
152
- # Get completions from Label Studio using the SDK
153
- annotations = self.get_completions()
154
- logger.info(f'Fetched {len(annotations)} annotations from Label Studio')
155
- completions = annotations
156
- except Exception as e:
157
- logger.error(f"Error fetching completions from Label Studio: {str(e)}")
158
- logger.error("Full error details:", exc_info=True)
159
- return {'status': 'error', 'message': 'Failed to fetch completions'}
160
-
161
- if not completions:
162
- logger.error("No completions to process")
163
- return {'status': 'error', 'message': 'No completions available'}
164
-
165
- texts = []
166
- labels = []
167
-
168
- # If completions is a list of single characters, join them
169
- if isinstance(completions, list) and all(isinstance(c, str) and len(c) == 1 for c in completions):
170
- completions = ''.join(completions)
171
- logger.info(f'Joined completions: {completions}')
172
-
173
- # Handle completions as a single string if needed
174
- if isinstance(completions, str):
175
- try:
176
- completions = json.loads(completions)
177
- logger.info('Successfully parsed completions JSON')
178
- except json.JSONDecodeError as e:
179
- logger.error(f"Failed to parse completions string as JSON: {str(e)}")
180
- logger.error(f"Problematic string: {completions}")
181
- return {'status': 'error', 'message': 'Invalid completions format'}
182
-
183
- # Ensure completions is a list
184
- if not isinstance(completions, list):
185
- completions = [completions]
186
-
187
- logger.info(f'Processing {len(completions)} items')
188
-
189
- for completion in completions:
190
- logger.info(f"Completion type: {type(completion)}")
191
- logger.info(f"Completion content: {completion}")
192
-
193
- try:
194
- # Convert string completion to dict if needed
195
- if isinstance(completion, str):
196
- completion = json.loads(completion)
197
-
198
- # Extract completion data
199
- completion_id = completion.get('id', 'unknown')
200
- logger.info(f"Processing completion ID: {completion_id}")
201
-
202
- # Get the task data containing the text
203
- text = completion.get('data', {}).get('text', '')
204
-
205
- # Get annotations/results
206
- annotations = completion.get('annotations', [])
207
- if not annotations and 'result' in completion:
208
- annotations = [{'result': completion['result']}]
209
-
210
- # Process each annotation
211
- for annotation in annotations:
212
- results = annotation.get('result', [])
213
-
214
- # Find the choices result
215
- for result in results:
216
- if result.get('type') == 'choices':
217
- choices = result.get('value', {}).get('choices', [])
218
- if choices:
219
- label = choices[0] # Take the first choice
220
- if text and label:
221
- texts.append(text)
222
- labels.append(label)
223
- logger.info(f"Added example - Text: {text[:50]}... Label: {label}")
224
 
225
- except Exception as e:
226
- logger.error(f"Error processing completion: {str(e)}")
227
- logger.error("Full error details:", exc_info=True)
228
- continue
229
-
230
- if not texts or not labels:
231
- logger.error("No valid training examples found")
232
- return {'status': 'error', 'message': 'No valid training examples found'}
233
-
234
- # Convert labels to integers
235
- label_encoder = LabelEncoder()
236
- encoded_labels = label_encoder.fit_transform(labels)
237
-
238
- # Save label encoder for inference
239
- self.label_encoder = label_encoder
240
- logger.info(f"Label mapping: {dict(zip(label_encoder.classes_, range(len(label_encoder.classes_))))}")
241
-
242
- # Create dataset
243
- dataset = TextDataset(texts, encoded_labels, self.tokenizer)
244
- dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
245
-
246
- # Training settings
247
- optimizer = AdamW(self.model.parameters(), lr=float(os.getenv('LEARNING_RATE', '2e-5')))
248
- num_epochs = int(os.getenv('NUM_TRAIN_EPOCHS', '3'))
249
-
250
- # Training loop
251
- logger.info(f"Starting training for {num_epochs} epochs")
252
- self.model.train()
253
-
254
- for epoch in range(num_epochs):
255
- total_loss = 0
256
- for batch in dataloader:
257
- optimizer.zero_grad()
258
-
259
- input_ids = batch['input_ids'].to(self.device)
260
- attention_mask = batch['attention_mask'].to(self.device)
261
- labels = batch['labels'].to(self.device)
262
-
263
- outputs = self.model(
264
- input_ids=input_ids,
265
- attention_mask=attention_mask,
266
- labels=labels
267
- )
268
-
269
- loss = outputs.loss
270
- total_loss += loss.item()
271
-
272
- loss.backward()
273
- optimizer.step()
274
 
275
- avg_loss = total_loss / len(dataloader)
276
- logger.info(f"Epoch {epoch + 1}/{num_epochs}, Average Loss: {avg_loss:.4f}")
277
-
278
- # Save the fine-tuned model
279
- model_dir = os.path.join(os.getenv('MODEL_DIR', ''), os.getenv('FINETUNED_MODEL_NAME', 'finetuned_model'))
280
- os.makedirs(model_dir, exist_ok=True)
281
- self.model.save_pretrained(model_dir)
282
- self.tokenizer.save_pretrained(model_dir)
283
-
284
- # Save label encoder
285
- with open(os.path.join(model_dir, 'label_encoder.json'), 'w') as f:
286
- json.dump({
287
- 'classes': label_encoder.classes_.tolist()
288
- }, f)
289
-
290
- logger.info(f"Model and label encoder saved to {model_dir}")
291
- return {'status': 'ok', 'message': f'Training completed with {len(texts)} examples'}
292
-
293
- except Exception as e:
294
- logger.error(f"Error during training: {str(e)}")
295
- logger.error("Full error details:", exc_info=True)
296
- return {'status': 'error', 'message': str(e)}
 
141
  logger.info(f"Returning {len(predictions)} predictions")
142
  return predictions
143
 
144
+ def fit(self, event, data, **kwargs):
145
+ """Train the model on the labeled data."""
146
+ logger.info(f"Received event: {event}")
147
+
148
+ # Check if the event is one that should trigger training
149
+ if event in ['ANNOTATION_CREATED', 'ANNOTATION_UPDATED']:
150
+ try:
151
+ # Fetch the full annotation data if not included in the payload
152
+ task_id = data.get('task_id')
153
+ if task_id:
154
+ annotation = self.label_studio_client.get_task(task_id)
155
+ logger.info(f"Fetched annotation for task ID: {task_id}")
156
+ else:
157
+ logger.error("No task ID found in event data")
158
+ return {'status': 'error', 'message': 'No task ID found'}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
+ # Extract text and label from the annotation
161
+ text = annotation.get('data', {}).get('text', '')
162
+ results = annotation.get('annotations', [{}])[0].get('result', [])
163
+ for result in results:
164
+ if result.get('type') == 'choices':
165
+ label = result.get('value', {}).get('choices', [])[0]
166
+ # Add your training logic here using text and label
167
+ logger.info(f"Training on text: {text[:50]}... with label: {label}")
168
+ # Example: self.train_model(text, label)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
+ except Exception as e:
171
+ logger.error(f"Error during training: {str(e)}")
172
+ logger.error("Full error details:", exc_info=True)
173
+ return {'status': 'error', 'message': str(e)}
174
+
175
+ return {'status': 'ok', 'message': 'Training completed'}