Spaces:
Sleeping
Sleeping
| import torch | |
| import logging | |
| import os | |
| import json | |
| from datetime import datetime | |
| from label_studio_ml.model import LabelStudioMLBase | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
| from torch.utils.data import DataLoader | |
| from torch.optim import AdamW | |
| from sklearn.preprocessing import LabelEncoder | |
| import sys | |
| from pathlib import Path | |
| from torch.utils.data import Dataset | |
| # Get the directory containing model.py | |
| current_dir = Path(__file__).parent | |
| logger = logging.getLogger(__name__) | |
# Lightweight torch Dataset pairing tokenized texts with integer class labels.
class TextDataset(Dataset):
    """Dataset of pre-tokenized texts and their integer class labels.

    Tokenization happens once, eagerly, in the constructor; per-item tensors
    are materialized lazily in __getitem__.
    """

    def __init__(self, texts, labels, tokenizer, max_length=128):
        # Tokenize the whole corpus up front (truncated/padded to max_length).
        self.encodings = tokenizer(texts, truncation=True, padding=True, max_length=max_length)
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        sample = {name: torch.tensor(values[idx]) for name, values in self.encodings.items()}
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample
class BertClassifier(LabelStudioMLBase):
    """Label Studio ML backend wrapping a BERT sequence classifier.

    Reads the labeling categories from the project's label config, loads
    training hyperparameters from environment variables, and initializes a
    `bert-base-uncased` tokenizer/model pair on the best available device.
    """

    def __init__(self, project_id=None, label_config=None, **kwargs):
        # NOTE(review): extra **kwargs are accepted but not forwarded to the
        # base class — confirm this is intentional.
        super(BertClassifier, self).__init__(project_id=project_id, label_config=label_config)
        # Set up model directory (created next to this file; checkpoints go here)
        self.model_dir = os.path.join(os.path.dirname(__file__), 'model')
        os.makedirs(self.model_dir, exist_ok=True)
        # Parse label config to get categories
        from label_studio_ml.model import parse_config
        parsed_config = parse_config(label_config)
        # Extract categories from the parsed config
        if not parsed_config:
            raise ValueError("Label config parsing returned empty result")
        # Find the first Choices tag in the config; any further Choices tags
        # are ignored.
        choices_tag = None
        for tag_name, tag_info in parsed_config.items():
            if tag_info.get('type') == 'Choices':
                choices_tag = tag_info
                break
        if not choices_tag:
            raise ValueError("No Choices tag found in label config")
        # Extract labels from the choices tag; their list order defines the
        # model's class-index mapping used in predict()/fit().
        self.categories = choices_tag.get('labels', [])
        if not self.categories:
            raise ValueError("No categories found in label config")
        # Load training configuration from environment variables with defaults
        self.learning_rate = float(os.getenv('LEARNING_RATE', '2e-5'))
        self.num_train_epochs = int(os.getenv('NUM_TRAIN_EPOCHS', '3'))
        self.weight_decay = float(os.getenv('WEIGHT_DECAY', '0.01'))
        self.start_training_threshold = int(os.getenv('START_TRAINING_EACH_N_UPDATES', '1'))
        logger.info("=== Training Configuration ===")
        logger.info(f"✓ Learning rate: {self.learning_rate}")
        logger.info(f"✓ Number of epochs: {self.num_train_epochs}")
        logger.info(f"✓ Weight decay: {self.weight_decay}")
        logger.info(f"✓ Training threshold: {self.start_training_threshold}")
        logger.info("============================")
        # Initialize tokenizer and model architecture (but not weights yet).
        # from_pretrained may download 'bert-base-uncased' on first run.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
        self._model = AutoModelForSequenceClassification.from_pretrained(
            'bert-base-uncased',
            num_labels=len(self.categories)
        )
        self._model.to(self.device)
| def initialize(self): | |
| """ | |
| Initialize model when server starts instead of when first prediction is requested. | |
| """ | |
| logger.info("=== INITIALIZING MODEL ===") | |
| # Load saved model if exists | |
| model_path = os.path.join(self.model_dir, 'model_state.pt') | |
| if os.path.exists(model_path): | |
| try: | |
| self._model.load_state_dict(torch.load(model_path)) | |
| logger.info(f"✓ Loaded saved model from {model_path}") | |
| except Exception as e: | |
| logger.error(f"Failed to load model: {str(e)}") | |
| logger.info("✓ Model ready") | |
| return self | |
    def predict(self, tasks, **kwargs):
        """Predict a single sentiment choice for each Label Studio task.

        Args:
            tasks: list of Label Studio task dicts; each must carry an 'id'
                and a string at task['data']['text'].

        Returns:
            A list with one Label Studio prediction dict per valid task, each
            holding a single 'choices' result and a confidence 'score'.
            Invalid tasks are logged and skipped rather than raising.
        """
        # Validation checks
        if not tasks:
            logger.error("No tasks received")
            return []
        if not self._model or not self.tokenizer:
            logger.error("Model or tokenizer not initialized")
            return []
        # Check if categories match the Label Studio config
        if not hasattr(self, 'categories') or not self.categories:
            logger.error("No categories configured")
            return []
        predictions = []
        for task_index, task in enumerate(tasks, 1):
            try:
                # Input validation
                if 'data' not in task or 'text' not in task['data']:
                    logger.error(f"Task {task_index}: Invalid task format")
                    continue
                if 'id' not in task:
                    logger.error(f"Task {task_index}: Missing task ID")
                    continue
                input_text = task['data']['text']
                if not input_text or not isinstance(input_text, str):
                    logger.error(f"Task {task_index}: Invalid input text")
                    continue
                # Model prediction: tokenize one text and move it to the
                # model's device.
                inputs = self.tokenizer(
                    input_text,
                    truncation=True,
                    padding=True,
                    return_tensors="pt"
                ).to(self.device)
                # Validate tokenized input
                if inputs['input_ids'].size(1) == 0:
                    logger.error(f"Task {task_index}: Empty tokenized input")
                    continue
                # Get model prediction (inference mode, no gradients)
                self._model.eval()
                with torch.no_grad():
                    outputs = self._model(**inputs)
                    logits = outputs.logits
                    probabilities = torch.softmax(logits, dim=1)
                # Get top 3 predictions with their probabilities (fewer if
                # the project has fewer than 3 categories)
                top_probs, top_indices = torch.topk(probabilities, min(3, len(self.categories)))
                # Format choices with probabilities
                # NOTE(review): only choices[0] is emitted below, so entries
                # after the first never reach the output.
                choices = []
                for prob, idx in zip(top_probs[0], top_indices[0]):
                    if prob.item() > 0.05:  # Only include predictions with >5% confidence
                        choices.append(self.categories[idx.item()])
                if not choices:  # If no prediction above threshold, use top prediction
                    choices = [self.categories[top_indices[0][0].item()]]
                confidence_score = top_probs[0][0].item()
                # Format prediction according to Label Studio requirements.
                # 'from_name'/'to_name' are hard-coded to the 'sentiment' and
                # 'text' tags — they must match the project's label config.
                prediction = {
                    'result': [{
                        'from_name': 'sentiment',
                        'to_name': 'text',
                        'type': 'choices',
                        'value': {
                            'choices': [choices[0]]
                        },
                        'score': confidence_score
                    }],
                    # model_version presumably provided by LabelStudioMLBase —
                    # confirm against the installed label_studio_ml version.
                    'model_version': str(self.model_version),
                    'task': task['id']
                }
                # Validate prediction format before emitting it
                if not self._validate_prediction(prediction):
                    logger.error(f"Task {task_index}: Invalid prediction format")
                    continue
                predictions.append(prediction)
            except Exception as e:
                logger.error(f"Error processing task {task_index}: {str(e)}", exc_info=True)
                continue
        return predictions
| def _validate_prediction(self, prediction): | |
| """Validate prediction format matches Label Studio requirements""" | |
| try: | |
| # Check basic structure | |
| if not isinstance(prediction, dict): | |
| logger.error("Prediction must be a dictionary") | |
| return False | |
| if 'result' not in prediction or not isinstance(prediction['result'], list): | |
| logger.error("Prediction must contain 'result' list") | |
| return False | |
| if not prediction['result']: | |
| logger.error("Prediction result list is empty") | |
| return False | |
| result = prediction['result'][0] | |
| # Check required fields | |
| required_fields = ['from_name', 'to_name', 'type', 'value'] | |
| for field in required_fields: | |
| if field not in result: | |
| logger.error(f"Missing required field: {field}") | |
| return False | |
| # Check value format | |
| if not isinstance(result['value'], dict) or 'choices' not in result['value']: | |
| logger.error("Invalid value format") | |
| return False | |
| # Check choices | |
| choices = result['value']['choices'] | |
| if not isinstance(choices, list) or not choices: | |
| logger.error("Invalid choices format") | |
| return False | |
| # Verify choice is in configured categories | |
| if choices[0] not in self.categories: | |
| logger.error(f"Predicted label '{choices[0]}' not in configured categories") | |
| return False | |
| return True | |
| except Exception as e: | |
| logger.error(f"Error validating prediction: {str(e)}") | |
| return False | |
| def fit(self, event_data, data=None, **kwargs): | |
| start_time = datetime.now() | |
| logger.info("=== FIT METHOD CALLED ===") | |
| try: | |
| if event_data == 'ANNOTATION_CREATED': | |
| # Check if we have enough annotations | |
| if self._get_annotation_count() < self.start_training_threshold: | |
| logger.info(f"Waiting for more annotations. Current: {self._get_annotation_count()}, Need: {self.start_training_threshold}") | |
| return {'status': 'ok', 'message': f'Waiting for more annotations ({self._get_annotation_count()}/{self.start_training_threshold})'} | |
| annotation = data.get('annotation', {}) | |
| task = data.get('task', {}) | |
| if not task or not annotation: | |
| logger.error("Missing task or annotation data") | |
| return {'status': 'error', 'message': 'Missing task or annotation data'} | |
| # Extract text and label | |
| text = task.get('data', {}).get('text', '') | |
| results = annotation.get('result', []) | |
| for result in results: | |
| if result.get('type') == 'choices': | |
| label = result.get('value', {}).get('choices', [])[0] | |
| logger.info(f"Training on - Text: {text[:50]}... Label: {label}") | |
| try: | |
| logger.info("Creating dataset...") | |
| dataset = TextDataset( | |
| texts=[text], | |
| labels=[self.categories.index(label)], | |
| tokenizer=self.tokenizer | |
| ) | |
| train_loader = DataLoader(dataset, batch_size=1) | |
| logger.info("✓ Dataset created") | |
| # Setup training | |
| optimizer = AdamW( | |
| self._model.parameters(), | |
| lr=self.learning_rate, | |
| weight_decay=self.weight_decay | |
| ) | |
| self._model.train() | |
| logger.info("Starting training...") | |
| # Training loop | |
| total_loss = 0 | |
| for epoch in range(self.num_train_epochs): | |
| logger.info(f"Starting epoch {epoch + 1}/{self.num_train_epochs}") | |
| epoch_loss = 0 | |
| for batch in train_loader: | |
| optimizer.zero_grad() | |
| # Move batch to device | |
| input_ids = batch['input_ids'].to(self.device) | |
| attention_mask = batch['attention_mask'].to(self.device) | |
| labels = batch['labels'].to(self.device) | |
| # Forward pass | |
| outputs = self._model( | |
| input_ids=input_ids, | |
| attention_mask=attention_mask, | |
| labels=labels | |
| ) | |
| loss = outputs.loss | |
| epoch_loss += loss.item() | |
| # Backward pass | |
| loss.backward() | |
| optimizer.step() | |
| avg_epoch_loss = epoch_loss / len(train_loader) | |
| total_loss += avg_epoch_loss | |
| logger.info(f"Epoch {epoch + 1} loss: {avg_epoch_loss:.4f}") | |
| avg_training_loss = total_loss / self.num_train_epochs | |
| logger.info(f"Average training loss: {avg_training_loss:.4f}") | |
| # Save model | |
| model_path = os.path.join(self.model_dir, 'model_state.pt') | |
| torch.save(self._model.state_dict(), model_path) | |
| logger.info(f"✓ Model saved to {model_path}") | |
| return { | |
| 'status': 'ok', | |
| 'message': f'Training completed with avg loss: {avg_training_loss:.4f}' | |
| } | |
| except Exception as e: | |
| logger.error(f"Training error: {str(e)}") | |
| return {'status': 'error', 'message': str(e)} | |
| except Exception as e: | |
| logger.error(f"Error in fit method: {str(e)}") | |
| logger.error("Full error details:", exc_info=True) | |
| return {'status': 'error', 'message': str(e)} | |
| return {'status': 'ok', 'message': 'Event processed'} | |
| def _get_annotation_count(self): | |
| """Helper method to get the current annotation count""" | |
| # This is a placeholder - you'll need to implement actual counting | |
| # For now, returning 1 to allow immediate training | |
| return 1 | |