import torch
import logging
import os
import json
import sys
from datetime import datetime
from pathlib import Path

from label_studio_ml.model import LabelStudioMLBase
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from sklearn.preprocessing import LabelEncoder

# Get the directory containing model.py
current_dir = Path(__file__).parent

logger = logging.getLogger(__name__)


class TextDataset(Dataset):
    """Minimal torch Dataset wrapping pre-tokenized texts and integer labels.

    Tokenizes all texts eagerly in __init__ (fine for the tiny per-annotation
    batches this backend trains on).
    """

    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.encodings = tokenizer(texts, truncation=True, padding=True, max_length=max_length)
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)


class BertClassifier(LabelStudioMLBase):
    """Label Studio ML backend: BERT sequence classifier over a Choices tag.

    Categories are read from the project's label config; training hyperparameters
    come from environment variables with sensible defaults.
    """

    def __init__(self, project_id=None, label_config=None, **kwargs):
        super().__init__(project_id=project_id, label_config=label_config)

        # Set up model directory (checkpoints are saved next to this file)
        self.model_dir = os.path.join(os.path.dirname(__file__), 'model')
        os.makedirs(self.model_dir, exist_ok=True)

        # Parse label config to get categories
        from label_studio_ml.model import parse_config
        parsed_config = parse_config(label_config)

        # Extract categories from the parsed config
        if not parsed_config:
            raise ValueError("Label config parsing returned empty result")

        # Find the first Choices tag in the config
        choices_tag = None
        for tag_name, tag_info in parsed_config.items():
            if tag_info.get('type') == 'Choices':
                choices_tag = tag_info
                break

        if not choices_tag:
            raise ValueError("No Choices tag found in label config")

        # Extract labels from the choices tag
        self.categories = choices_tag.get('labels', [])
        if not self.categories:
            raise ValueError("No categories found in label config")

        # Load training configuration from environment variables with defaults
        self.learning_rate = float(os.getenv('LEARNING_RATE', '2e-5'))
        self.num_train_epochs = int(os.getenv('NUM_TRAIN_EPOCHS', '3'))
        self.weight_decay = float(os.getenv('WEIGHT_DECAY', '0.01'))
        self.start_training_threshold = int(os.getenv('START_TRAINING_EACH_N_UPDATES', '1'))

        logger.info("=== Training Configuration ===")
        logger.info(f"✓ Learning rate: {self.learning_rate}")
        logger.info(f"✓ Number of epochs: {self.num_train_epochs}")
        logger.info(f"✓ Weight decay: {self.weight_decay}")
        logger.info(f"✓ Training threshold: {self.start_training_threshold}")
        logger.info("============================")

        # Initialize tokenizer and model architecture (but not weights yet)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
        self._model = AutoModelForSequenceClassification.from_pretrained(
            'bert-base-uncased',
            num_labels=len(self.categories)
        )
        self._model.to(self.device)

    def initialize(self):
        """
        Initialize model when server starts instead of when first prediction is requested.
        """
        logger.info("=== INITIALIZING MODEL ===")

        # Load saved model if exists
        model_path = os.path.join(self.model_dir, 'model_state.pt')
        if os.path.exists(model_path):
            try:
                # map_location is required so a checkpoint saved on GPU can be
                # restored on a CPU-only host (and vice versa)
                self._model.load_state_dict(torch.load(model_path, map_location=self.device))
                logger.info(f"✓ Loaded saved model from {model_path}")
            except Exception as e:
                logger.error(f"Failed to load model: {str(e)}")

        logger.info("✓ Model ready")
        return self

    def predict(self, tasks, **kwargs):
        """Return one Label Studio prediction per valid task.

        Each prediction carries the single top label (by softmax probability)
        for the configured Choices control, plus its confidence score.
        Malformed tasks are skipped with a logged error.
        """
        # Validation checks
        if not tasks:
            logger.error("No tasks received")
            return []

        if not self._model or not self.tokenizer:
            logger.error("Model or tokenizer not initialized")
            return []

        # Check if categories match the Label Studio config
        if not hasattr(self, 'categories') or not self.categories:
            logger.error("No categories configured")
            return []

        predictions = []
        for task_index, task in enumerate(tasks, 1):
            try:
                # Input validation
                if 'data' not in task or 'text' not in task['data']:
                    logger.error(f"Task {task_index}: Invalid task format")
                    continue

                if 'id' not in task:
                    logger.error(f"Task {task_index}: Missing task ID")
                    continue

                input_text = task['data']['text']
                if not input_text or not isinstance(input_text, str):
                    logger.error(f"Task {task_index}: Invalid input text")
                    continue

                # Model prediction
                inputs = self.tokenizer(
                    input_text,
                    truncation=True,
                    padding=True,
                    return_tensors="pt"
                ).to(self.device)

                # Validate tokenized input
                if inputs['input_ids'].size(1) == 0:
                    logger.error(f"Task {task_index}: Empty tokenized input")
                    continue

                # Get model prediction
                self._model.eval()
                with torch.no_grad():
                    outputs = self._model(**inputs)
                    logits = outputs.logits
                    probabilities = torch.softmax(logits, dim=1)

                # Get top 3 predictions with their probabilities
                top_probs, top_indices = torch.topk(probabilities, min(3, len(self.categories)))

                # Format choices with probabilities
                choices = []
                for prob, idx in zip(top_probs[0], top_indices[0]):
                    if prob.item() > 0.05:  # Only include predictions with >5% confidence
                        choices.append(self.categories[idx.item()])

                if not choices:  # If no prediction above threshold, use top prediction
                    choices = [self.categories[top_indices[0][0].item()]]

                confidence_score = top_probs[0][0].item()

                # Format prediction according to Label Studio requirements
                # NOTE(review): from_name/to_name are hard-coded; they must match
                # the names used in the project's labeling config — confirm
                prediction = {
                    'result': [{
                        'from_name': 'sentiment',
                        'to_name': 'text',
                        'type': 'choices',
                        'value': {
                            'choices': [choices[0]]
                        },
                        'score': confidence_score
                    }],
                    'model_version': str(self.model_version),
                    'task': task['id']
                }

                # Validate prediction format
                if not self._validate_prediction(prediction):
                    logger.error(f"Task {task_index}: Invalid prediction format")
                    continue

                predictions.append(prediction)

            except Exception as e:
                logger.error(f"Error processing task {task_index}: {str(e)}", exc_info=True)
                continue

        return predictions

    def _validate_prediction(self, prediction):
        """Validate prediction format matches Label Studio requirements"""
        try:
            # Check basic structure
            if not isinstance(prediction, dict):
                logger.error("Prediction must be a dictionary")
                return False

            if 'result' not in prediction or not isinstance(prediction['result'], list):
                logger.error("Prediction must contain 'result' list")
                return False

            if not prediction['result']:
                logger.error("Prediction result list is empty")
                return False

            result = prediction['result'][0]

            # Check required fields
            required_fields = ['from_name', 'to_name', 'type', 'value']
            for field in required_fields:
                if field not in result:
                    logger.error(f"Missing required field: {field}")
                    return False

            # Check value format
            if not isinstance(result['value'], dict) or 'choices' not in result['value']:
                logger.error("Invalid value format")
                return False

            # Check choices
            choices = result['value']['choices']
            if not isinstance(choices, list) or not choices:
                logger.error("Invalid choices format")
                return False

            # Verify choice is in configured categories
            if choices[0] not in self.categories:
                logger.error(f"Predicted label '{choices[0]}' not in configured categories")
                return False

            return True

        except Exception as e:
            logger.error(f"Error validating prediction: {str(e)}")
            return False

    def fit(self, event_data, data=None, **kwargs):
        """Fine-tune the model on a single newly created annotation.

        Triggered by Label Studio webhooks; trains only on ANNOTATION_CREATED
        events once the annotation-count threshold is met, then saves the
        updated state dict to disk.
        """
        start_time = datetime.now()
        logger.info("=== FIT METHOD CALLED ===")

        try:
            if event_data == 'ANNOTATION_CREATED':
                # Check if we have enough annotations
                if self._get_annotation_count() < self.start_training_threshold:
                    logger.info(f"Waiting for more annotations. Current: {self._get_annotation_count()}, Need: {self.start_training_threshold}")
                    return {'status': 'ok', 'message': f'Waiting for more annotations ({self._get_annotation_count()}/{self.start_training_threshold})'}

                # Guard: `data` defaults to None, so never call .get() on it blindly
                data = data or {}
                annotation = data.get('annotation', {})
                task = data.get('task', {})

                if not task or not annotation:
                    logger.error("Missing task or annotation data")
                    return {'status': 'error', 'message': 'Missing task or annotation data'}

                # Extract text and label
                text = task.get('data', {}).get('text', '')
                results = annotation.get('result', [])

                for result in results:
                    if result.get('type') == 'choices':
                        # Guard: an empty choices list would raise IndexError
                        choices_list = result.get('value', {}).get('choices') or []
                        if not choices_list:
                            logger.error("Annotation result has empty choices; skipping")
                            continue
                        label = choices_list[0]
                        # Guard: unknown label would raise ValueError in .index()
                        if label not in self.categories:
                            logger.error(f"Label '{label}' not in configured categories; skipping")
                            continue
                        logger.info(f"Training on - Text: {text[:50]}... Label: {label}")

                        try:
                            logger.info("Creating dataset...")
                            dataset = TextDataset(
                                texts=[text],
                                labels=[self.categories.index(label)],
                                tokenizer=self.tokenizer
                            )
                            train_loader = DataLoader(dataset, batch_size=1)
                            logger.info("✓ Dataset created")

                            # Setup training
                            optimizer = AdamW(
                                self._model.parameters(),
                                lr=self.learning_rate,
                                weight_decay=self.weight_decay
                            )

                            self._model.train()
                            logger.info("Starting training...")

                            # Training loop
                            total_loss = 0
                            for epoch in range(self.num_train_epochs):
                                logger.info(f"Starting epoch {epoch + 1}/{self.num_train_epochs}")
                                epoch_loss = 0
                                for batch in train_loader:
                                    optimizer.zero_grad()

                                    # Move batch to device
                                    input_ids = batch['input_ids'].to(self.device)
                                    attention_mask = batch['attention_mask'].to(self.device)
                                    labels = batch['labels'].to(self.device)

                                    # Forward pass
                                    outputs = self._model(
                                        input_ids=input_ids,
                                        attention_mask=attention_mask,
                                        labels=labels
                                    )
                                    loss = outputs.loss
                                    epoch_loss += loss.item()

                                    # Backward pass
                                    loss.backward()
                                    optimizer.step()

                                avg_epoch_loss = epoch_loss / len(train_loader)
                                total_loss += avg_epoch_loss
                                logger.info(f"Epoch {epoch + 1} loss: {avg_epoch_loss:.4f}")

                            avg_training_loss = total_loss / self.num_train_epochs
                            logger.info(f"Average training loss: {avg_training_loss:.4f}")

                            # Save model
                            model_path = os.path.join(self.model_dir, 'model_state.pt')
                            torch.save(self._model.state_dict(), model_path)
                            logger.info(f"✓ Model saved to {model_path}")

                            return {
                                'status': 'ok',
                                'message': f'Training completed with avg loss: {avg_training_loss:.4f}'
                            }

                        except Exception as e:
                            logger.error(f"Training error: {str(e)}")
                            return {'status': 'error', 'message': str(e)}

        except Exception as e:
            logger.error(f"Error in fit method: {str(e)}")
            logger.error("Full error details:", exc_info=True)
            return {'status': 'error', 'message': str(e)}

        return {'status': 'ok', 'message': 'Event processed'}

    def _get_annotation_count(self):
        """Helper method to get the current annotation count"""
        # This is a placeholder - you'll need to implement actual counting
        # For now, returning 1 to allow immediate training
        return 1