# rlhf_docker / model.py
# NOTE: the following metadata was file-viewer residue accidentally pasted into
# the module (uploader "b2u", commit message "epochs = 3", commit d643e60);
# it has been converted to comments so the module remains importable.
import torch
import logging
import os
import json
from datetime import datetime
from label_studio_ml.model import LabelStudioMLBase
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from torch.utils.data import DataLoader
from torch.optim import AdamW
from sklearn.preprocessing import LabelEncoder
import sys
from pathlib import Path
from torch.utils.data import Dataset
# Directory that contains this model.py, for locating co-located assets.
current_dir = Path(__file__).parent

# Module-level logger; handler/level configuration is left to the host app.
logger = logging.getLogger(__name__)
class TextDataset(Dataset):
    """Minimal map-style dataset pairing tokenized texts with integer labels.

    The entire corpus is tokenized once up front; per-item tensors are built
    lazily in ``__getitem__`` so the dataset stays cheap to construct.
    """

    def __init__(self, texts, labels, tokenizer, max_length=128):
        # Tokenizer returns a dict of parallel lists (input_ids, attention_mask, ...)
        self.encodings = tokenizer(texts, truncation=True, padding=True, max_length=max_length)
        self.labels = labels

    def __getitem__(self, idx):
        sample = {}
        for field, values in self.encodings.items():
            sample[field] = torch.tensor(values[idx])
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample

    def __len__(self):
        return len(self.labels)
class BertClassifier(LabelStudioMLBase):
    """Label Studio ML backend that fine-tunes a BERT sequence classifier.

    The label set is taken from the first ``Choices`` tag of the project's
    labeling config; training hyperparameters come from environment variables.
    """

    def __init__(self, project_id=None, label_config=None, **kwargs):
        """Set up model directory, label set, hyperparameters, tokenizer and model.

        Args:
            project_id: Label Studio project id, forwarded to the base class.
            label_config: Labeling-config XML; must contain a ``Choices`` tag.
            **kwargs: Accepted for interface compatibility; currently unused.

        Raises:
            ValueError: If the config parses to nothing, has no ``Choices``
                tag, or the ``Choices`` tag declares no labels.
        """
        super(BertClassifier, self).__init__(project_id=project_id, label_config=label_config)
        # Persist trained weights next to this file, under ./model
        self.model_dir = os.path.join(os.path.dirname(__file__), 'model')
        os.makedirs(self.model_dir, exist_ok=True)
        # Parse label config to get categories
        from label_studio_ml.model import parse_config
        parsed_config = parse_config(label_config)
        # Extract categories from the parsed config
        if not parsed_config:
            raise ValueError("Label config parsing returned empty result")
        # Find the first Choices tag in the config
        choices_tag = None
        for tag_name, tag_info in parsed_config.items():
            if tag_info.get('type') == 'Choices':
                choices_tag = tag_info
                break
        if not choices_tag:
            raise ValueError("No Choices tag found in label config")
        # Extract labels from the choices tag
        self.categories = choices_tag.get('labels', [])
        if not self.categories:
            raise ValueError("No categories found in label config")
        # Load training configuration from environment variables with defaults
        self.learning_rate = float(os.getenv('LEARNING_RATE', '2e-5'))
        self.num_train_epochs = int(os.getenv('NUM_TRAIN_EPOCHS', '3'))
        self.weight_decay = float(os.getenv('WEIGHT_DECAY', '0.01'))
        # Number of annotations required before fit() actually trains.
        self.start_training_threshold = int(os.getenv('START_TRAINING_EACH_N_UPDATES', '1'))
        logger.info("=== Training Configuration ===")
        logger.info(f"✓ Learning rate: {self.learning_rate}")
        logger.info(f"✓ Number of epochs: {self.num_train_epochs}")
        logger.info(f"✓ Weight decay: {self.weight_decay}")
        logger.info(f"✓ Training threshold: {self.start_training_threshold}")
        logger.info("============================")
        # Initialize tokenizer and model architecture (but not weights yet);
        # any fine-tuned weights are loaded later by initialize().
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
        # Classification head is sized to the configured label set.
        self._model = AutoModelForSequenceClassification.from_pretrained(
            'bert-base-uncased',
            num_labels=len(self.categories)
        )
        self._model.to(self.device)
def initialize(self):
"""
Initialize model when server starts instead of when first prediction is requested.
"""
logger.info("=== INITIALIZING MODEL ===")
# Load saved model if exists
model_path = os.path.join(self.model_dir, 'model_state.pt')
if os.path.exists(model_path):
try:
self._model.load_state_dict(torch.load(model_path))
logger.info(f"✓ Loaded saved model from {model_path}")
except Exception as e:
logger.error(f"Failed to load model: {str(e)}")
logger.info("✓ Model ready")
return self
    def predict(self, tasks, **kwargs):
        """Return Label Studio predictions for a batch of tasks.

        Each task must carry its text under ``task['data']['text']`` and a
        task ``id``. For every valid task the model's top softmax class is
        emitted as a single 'choices' result with its probability as the
        score; malformed tasks are logged and skipped rather than failing the
        whole batch.

        Args:
            tasks: List of Label Studio task dicts.
            **kwargs: Unused; accepted for interface compatibility.

        Returns:
            List of Label Studio prediction dicts (possibly empty).
        """
        # Validation checks
        if not tasks:
            logger.error("No tasks received")
            return []
        if not self._model or not self.tokenizer:
            logger.error("Model or tokenizer not initialized")
            return []
        # Check if categories match the Label Studio config
        if not hasattr(self, 'categories') or not self.categories:
            logger.error("No categories configured")
            return []
        predictions = []
        for task_index, task in enumerate(tasks, 1):
            try:
                # Input validation
                if 'data' not in task or 'text' not in task['data']:
                    logger.error(f"Task {task_index}: Invalid task format")
                    continue
                if 'id' not in task:
                    logger.error(f"Task {task_index}: Missing task ID")
                    continue
                input_text = task['data']['text']
                if not input_text or not isinstance(input_text, str):
                    logger.error(f"Task {task_index}: Invalid input text")
                    continue
                # Model prediction (tokenizer truncates to its default max length)
                inputs = self.tokenizer(
                    input_text,
                    truncation=True,
                    padding=True,
                    return_tensors="pt"
                ).to(self.device)
                # Validate tokenized input
                if inputs['input_ids'].size(1) == 0:
                    logger.error(f"Task {task_index}: Empty tokenized input")
                    continue
                # Get model prediction
                self._model.eval()
                with torch.no_grad():
                    outputs = self._model(**inputs)
                    logits = outputs.logits
                    probabilities = torch.softmax(logits, dim=1)
                # Get top 3 predictions with their probabilities
                top_probs, top_indices = torch.topk(probabilities, min(3, len(self.categories)))
                # Format choices with probabilities
                choices = []
                for prob, idx in zip(top_probs[0], top_indices[0]):
                    if prob.item() > 0.05:  # Only include predictions with >5% confidence
                        choices.append(self.categories[idx.item()])
                if not choices:  # If no prediction above threshold, use top prediction
                    choices = [self.categories[top_indices[0][0].item()]]
                # Only the single best label is reported; its probability is
                # used as the prediction score.
                confidence_score = top_probs[0][0].item()
                # Format prediction according to Label Studio requirements.
                # NOTE(review): 'from_name'/'to_name' are hard-coded as
                # 'sentiment'/'text' — verify they match the labeling config.
                prediction = {
                    'result': [{
                        'from_name': 'sentiment',
                        'to_name': 'text',
                        'type': 'choices',
                        'value': {
                            'choices': [choices[0]]
                        },
                        'score': confidence_score
                    }],
                    # model_version is presumably provided by LabelStudioMLBase
                    # — TODO confirm against the installed label_studio_ml version.
                    'model_version': str(self.model_version),
                    'task': task['id']
                }
                # Validate prediction format
                if not self._validate_prediction(prediction):
                    logger.error(f"Task {task_index}: Invalid prediction format")
                    continue
                predictions.append(prediction)
            except Exception as e:
                logger.error(f"Error processing task {task_index}: {str(e)}", exc_info=True)
                continue
        return predictions
def _validate_prediction(self, prediction):
"""Validate prediction format matches Label Studio requirements"""
try:
# Check basic structure
if not isinstance(prediction, dict):
logger.error("Prediction must be a dictionary")
return False
if 'result' not in prediction or not isinstance(prediction['result'], list):
logger.error("Prediction must contain 'result' list")
return False
if not prediction['result']:
logger.error("Prediction result list is empty")
return False
result = prediction['result'][0]
# Check required fields
required_fields = ['from_name', 'to_name', 'type', 'value']
for field in required_fields:
if field not in result:
logger.error(f"Missing required field: {field}")
return False
# Check value format
if not isinstance(result['value'], dict) or 'choices' not in result['value']:
logger.error("Invalid value format")
return False
# Check choices
choices = result['value']['choices']
if not isinstance(choices, list) or not choices:
logger.error("Invalid choices format")
return False
# Verify choice is in configured categories
if choices[0] not in self.categories:
logger.error(f"Predicted label '{choices[0]}' not in configured categories")
return False
return True
except Exception as e:
logger.error(f"Error validating prediction: {str(e)}")
return False
def fit(self, event_data, data=None, **kwargs):
start_time = datetime.now()
logger.info("=== FIT METHOD CALLED ===")
try:
if event_data == 'ANNOTATION_CREATED':
# Check if we have enough annotations
if self._get_annotation_count() < self.start_training_threshold:
logger.info(f"Waiting for more annotations. Current: {self._get_annotation_count()}, Need: {self.start_training_threshold}")
return {'status': 'ok', 'message': f'Waiting for more annotations ({self._get_annotation_count()}/{self.start_training_threshold})'}
annotation = data.get('annotation', {})
task = data.get('task', {})
if not task or not annotation:
logger.error("Missing task or annotation data")
return {'status': 'error', 'message': 'Missing task or annotation data'}
# Extract text and label
text = task.get('data', {}).get('text', '')
results = annotation.get('result', [])
for result in results:
if result.get('type') == 'choices':
label = result.get('value', {}).get('choices', [])[0]
logger.info(f"Training on - Text: {text[:50]}... Label: {label}")
try:
logger.info("Creating dataset...")
dataset = TextDataset(
texts=[text],
labels=[self.categories.index(label)],
tokenizer=self.tokenizer
)
train_loader = DataLoader(dataset, batch_size=1)
logger.info("✓ Dataset created")
# Setup training
optimizer = AdamW(
self._model.parameters(),
lr=self.learning_rate,
weight_decay=self.weight_decay
)
self._model.train()
logger.info("Starting training...")
# Training loop
total_loss = 0
for epoch in range(self.num_train_epochs):
logger.info(f"Starting epoch {epoch + 1}/{self.num_train_epochs}")
epoch_loss = 0
for batch in train_loader:
optimizer.zero_grad()
# Move batch to device
input_ids = batch['input_ids'].to(self.device)
attention_mask = batch['attention_mask'].to(self.device)
labels = batch['labels'].to(self.device)
# Forward pass
outputs = self._model(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels
)
loss = outputs.loss
epoch_loss += loss.item()
# Backward pass
loss.backward()
optimizer.step()
avg_epoch_loss = epoch_loss / len(train_loader)
total_loss += avg_epoch_loss
logger.info(f"Epoch {epoch + 1} loss: {avg_epoch_loss:.4f}")
avg_training_loss = total_loss / self.num_train_epochs
logger.info(f"Average training loss: {avg_training_loss:.4f}")
# Save model
model_path = os.path.join(self.model_dir, 'model_state.pt')
torch.save(self._model.state_dict(), model_path)
logger.info(f"✓ Model saved to {model_path}")
return {
'status': 'ok',
'message': f'Training completed with avg loss: {avg_training_loss:.4f}'
}
except Exception as e:
logger.error(f"Training error: {str(e)}")
return {'status': 'error', 'message': str(e)}
except Exception as e:
logger.error(f"Error in fit method: {str(e)}")
logger.error("Full error details:", exc_info=True)
return {'status': 'error', 'message': str(e)}
return {'status': 'ok', 'message': 'Event processed'}
def _get_annotation_count(self):
"""Helper method to get the current annotation count"""
# This is a placeholder - you'll need to implement actual counting
# For now, returning 1 to allow immediate training
return 1