File size: 15,782 Bytes
cbf01bb
 
1b11c8f
c6dedc8
4debd04
cbf01bb
1b11c8f
4debd04
 
f9c8fe4
e81ab68
0579428
6de324b
0579428
 
 
 
cbf01bb
 
6de324b
 
 
 
 
 
 
 
 
 
 
 
 
 
cbf01bb
41e8855
 
 
f85629e
 
 
 
95dbb1c
 
666a18b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a3aa3ec
 
666a18b
 
 
b537bff
 
d643e60
b537bff
 
d450991
 
 
 
 
 
 
50e2ec7
0e770ae
3178a98
0e770ae
1589415
 
 
 
0e770ae
 
 
 
 
 
 
1818eaa
1589415
 
 
 
 
 
 
 
1818eaa
1589415
50e2ec7
cbf01bb
8f82581
90631ed
 
 
 
 
 
 
 
9c980be
27c34a1
 
 
 
90631ed
150b3d1
535fc0a
 
150b3d1
90631ed
 
 
 
27c34a1
 
 
 
 
98babd1
90631ed
 
 
 
 
c4e7614
98babd1
c4e7614
 
 
 
 
90631ed
 
 
 
 
98babd1
 
c4e7614
 
 
 
 
851f3a2
 
 
 
 
 
 
 
 
 
 
 
 
c4e7614
27c34a1
535fc0a
541d9b7
c6d8e9b
 
541d9b7
 
c6d8e9b
535fc0a
c4e7614
541d9b7
2ef1a4d
535fc0a
 
851f3a2
27c34a1
 
 
 
 
535fc0a
541d9b7
150b3d1
98babd1
150b3d1
535fc0a
6bdfcdd
cbf01bb
27c34a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22c8159
b262013
150b3d1
b9a9837
4a96163
 
d450991
 
 
 
 
4a96163
 
 
 
 
 
 
b262013
4a96163
 
b262013
b9a9837
 
 
b262013
4a96163
b262013
b8bda62
b262013
 
 
 
 
 
150b3d1
b262013
 
d450991
 
 
 
 
b262013
150b3d1
b262013
1589415
 
d450991
 
1589415
b262013
d450991
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1589415
d450991
 
 
 
b262013
1589415
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b262013
 
 
1589415
b262013
4a96163
b262013
4a96163
 
b9a9837
4a96163
d450991
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
import torch
import logging
import os
import json
from datetime import datetime
from label_studio_ml.model import LabelStudioMLBase
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from torch.utils.data import DataLoader
from torch.optim import AdamW
from sklearn.preprocessing import LabelEncoder
import sys
from pathlib import Path
from torch.utils.data import Dataset

# Directory that contains this module file.
current_dir = Path(__file__).parent

# Module-level logger; the classifier below reports progress and errors through it.
logger = logging.getLogger(__name__)

# Dataset wrapper that tokenizes texts up front and serves (encoding, label) tensors for training
class TextDataset(Dataset):
    """Torch dataset pairing tokenized texts with their class labels.

    The whole list of texts is tokenized once in the constructor; indexing
    then converts the stored encodings and the label to tensors.
    """

    def __init__(self, texts, labels, tokenizer, max_length=128):
        """Tokenize ``texts`` and remember ``labels``.

        Args:
            texts: Input texts, passed to ``tokenizer`` in a single call.
            labels: One label per text, indexable by position.
            tokenizer: Callable invoked as
                ``tokenizer(texts, truncation=True, padding=True, max_length=...)``;
                its result must expose ``.items()`` yielding per-field sequences.
            max_length: Truncation length forwarded to the tokenizer.
        """
        self.labels = labels
        self.encodings = tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=max_length,
        )

    def __len__(self):
        """Return the number of examples (one per label)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return a dict of tensors for example ``idx``, including ``'labels'``."""
        sample = {}
        for field, values in self.encodings.items():
            sample[field] = torch.tensor(values[idx])
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample

class BertClassifier(LabelStudioMLBase):
    def __init__(self, project_id=None, label_config=None, **kwargs):
        """Prepare categories, training settings, tokenizer and model.

        Steps, in order: create the ``model`` directory next to this file,
        parse ``label_config`` and take the labels of the first ``Choices``
        tag as the categories, read hyper-parameters from environment
        variables, then instantiate the ``bert-base-uncased`` tokenizer and a
        sequence classifier with one output per category. A checkpoint saved
        by ``fit`` is not loaded here; ``initialize`` does that.

        Args:
            project_id: Forwarded to the base class constructor.
            label_config: Labeling config; forwarded to the base class and
                parsed to discover the categories.
            **kwargs: Accepted but not used.

        Raises:
            ValueError: If the parsed config is empty, contains no ``Choices``
                tag, or that tag lists no labels.
        """
        super().__init__(project_id=project_id, label_config=label_config)

        # Checkpoints live in a "model" folder beside this source file.
        module_dir = os.path.dirname(__file__)
        self.model_dir = os.path.join(module_dir, 'model')
        os.makedirs(self.model_dir, exist_ok=True)

        # Discover the categories from the labeling config.
        from label_studio_ml.model import parse_config
        config_tags = parse_config(label_config)
        if not config_tags:
            raise ValueError("Label config parsing returned empty result")

        # First tag whose type is "Choices", or None when there is none.
        choices_tag = next(
            (info for info in config_tags.values() if info.get('type') == 'Choices'),
            None,
        )
        if not choices_tag:
            raise ValueError("No Choices tag found in label config")

        self.categories = choices_tag.get('labels', [])
        if not self.categories:
            raise ValueError("No categories found in label config")

        # Hyper-parameters come from the environment, with defaults.
        env = os.getenv
        self.learning_rate = float(env('LEARNING_RATE', '2e-5'))
        self.num_train_epochs = int(env('NUM_TRAIN_EPOCHS', '3'))
        self.weight_decay = float(env('WEIGHT_DECAY', '0.01'))
        self.start_training_threshold = int(env('START_TRAINING_EACH_N_UPDATES', '1'))

        for line in (
            "=== Training Configuration ===",
            f"βœ“ Learning rate: {self.learning_rate}",
            f"βœ“ Number of epochs: {self.num_train_epochs}",
            f"βœ“ Weight decay: {self.weight_decay}",
            f"βœ“ Training threshold: {self.start_training_threshold}",
            "============================",
        ):
            logger.info(line)

        # Pick the GPU when one is available, otherwise stay on the CPU.
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')

        base_checkpoint = 'bert-base-uncased'
        self.tokenizer = AutoTokenizer.from_pretrained(base_checkpoint)
        self._model = AutoModelForSequenceClassification.from_pretrained(
            base_checkpoint,
            num_labels=len(self.categories),
        )
        self._model.to(self.device)

    def initialize(self):
        """Load previously saved fine-tuned weights, if present.

        Looks for ``model_state.pt`` inside ``self.model_dir``. When the file
        exists its state dict is loaded into ``self._model``; a failure to
        load is logged and otherwise ignored, so the model keeps the weights
        it already had.

        Returns:
            ``self``, so the call can be chained.
        """
        logger.info("=== INITIALIZING MODEL ===")

        model_path = os.path.join(self.model_dir, 'model_state.pt')
        if os.path.exists(model_path):
            try:
                # map_location remaps tensors that were saved on a different
                # device (e.g. a CUDA checkpoint opened on a CPU-only host)
                # onto the device this instance actually runs on.
                state_dict = torch.load(model_path, map_location=self.device)
                self._model.load_state_dict(state_dict)
                logger.info(f"βœ“ Loaded saved model from {model_path}")
            except Exception as e:
                # Best effort: keep serving with the current weights.
                logger.error(f"Failed to load model: {str(e)}")

        logger.info("βœ“ Model ready")
        return self

    def predict(self, tasks, **kwargs):
        """Classify the text of each task and build Label Studio predictions.

        Args:
            tasks: Sequence of task dicts. Each task needs an ``'id'`` and a
                non-empty string at ``task['data']['text']``.
            **kwargs: Accepted but not used.

        Returns:
            A list of prediction dicts with the keys ``'result'``,
            ``'model_version'`` and ``'task'``. Each result carries the single
            most probable category and its softmax probability as ``'score'``.
            Tasks that fail validation or raise are logged and skipped, so the
            list can be shorter than ``tasks``. An empty list is returned when
            there are no tasks or the model, tokenizer or categories are
            missing.
        """
        if not tasks:
            logger.error("No tasks received")
            return []

        # Identity checks: truthiness of a tokenizer/model object says nothing
        # about whether it was initialized.
        if getattr(self, '_model', None) is None or getattr(self, 'tokenizer', None) is None:
            logger.error("Model or tokenizer not initialized")
            return []

        if not hasattr(self, 'categories') or not self.categories:
            logger.error("No categories configured")
            return []

        # Switching to evaluation mode does not depend on the task, so do it
        # once instead of once per task.
        self._model.eval()

        predictions = []

        # NOTE(review): skipped tasks produce no entry in the output; confirm
        # that the caller does not rely on positional alignment with `tasks`.
        for task_index, task in enumerate(tasks, 1):
            try:
                # Input validation
                if 'data' not in task or 'text' not in task['data']:
                    logger.error(f"Task {task_index}: Invalid task format")
                    continue

                if 'id' not in task:
                    logger.error(f"Task {task_index}: Missing task ID")
                    continue

                input_text = task['data']['text']
                if not input_text or not isinstance(input_text, str):
                    logger.error(f"Task {task_index}: Invalid input text")
                    continue

                inputs = self.tokenizer(
                    input_text,
                    truncation=True,
                    padding=True,
                    return_tensors="pt"
                ).to(self.device)

                if inputs['input_ids'].size(1) == 0:
                    logger.error(f"Task {task_index}: Empty tokenized input")
                    continue

                with torch.no_grad():
                    logits = self._model(**inputs).logits
                    probabilities = torch.softmax(logits, dim=1)
                    # Only the best label is reported, so a plain max over the
                    # class dimension is all that is needed.
                    top_prob, top_index = torch.max(probabilities, dim=1)

                confidence_score = top_prob[0].item()
                predicted_label = self.categories[top_index[0].item()]

                # Format prediction according to Label Studio requirements
                prediction = {
                    'result': [{
                        'from_name': 'sentiment',
                        'to_name': 'text',
                        'type': 'choices',
                        'value': {
                            'choices': [predicted_label]
                        },
                        'score': confidence_score
                    }],
                    'model_version': str(self.model_version),
                    'task': task['id']
                }

                if not self._validate_prediction(prediction):
                    logger.error(f"Task {task_index}: Invalid prediction format")
                    continue

                predictions.append(prediction)

            except Exception as e:
                # One bad task must not abort the rest of the batch.
                logger.error(f"Error processing task {task_index}: {str(e)}", exc_info=True)
                continue

        return predictions

    def _validate_prediction(self, prediction):
        """Check that a prediction has the structure this backend emits.

        Only the first entry of ``prediction['result']`` and the first entry
        of its ``choices`` list are inspected. Every failed check is logged.

        Args:
            prediction: Candidate prediction object.

        Returns:
            ``True`` when all checks pass; ``False`` on the first failed check
            or when a check itself raises.
        """
        def reject(reason):
            # Log why the prediction was refused and report failure.
            logger.error(reason)
            return False

        try:
            # Top-level container
            if not isinstance(prediction, dict):
                return reject("Prediction must be a dictionary")

            results = prediction.get('result')
            if not isinstance(results, list):
                return reject("Prediction must contain 'result' list")

            if not results:
                return reject("Prediction result list is empty")

            first = results[0]

            # Mandatory keys of a result entry; report the first one missing.
            missing = next(
                (key for key in ('from_name', 'to_name', 'type', 'value') if key not in first),
                None,
            )
            if missing is not None:
                return reject(f"Missing required field: {missing}")

            # The value must be a dict that carries a "choices" entry.
            value = first['value']
            if not (isinstance(value, dict) and 'choices' in value):
                return reject("Invalid value format")

            # "choices" must be a non-empty list.
            picked = value['choices']
            if not (isinstance(picked, list) and picked):
                return reject("Invalid choices format")

            # The predicted label has to be one of the configured categories.
            if picked[0] not in self.categories:
                return reject(f"Predicted label '{picked[0]}' not in configured categories")

            return True

        except Exception as e:
            logger.error(f"Error validating prediction: {str(e)}")
            return False

    def fit(self, event_data, data=None, **kwargs):
        """Fine-tune the classifier on a newly created annotation.

        Only the ``'ANNOTATION_CREATED'`` event triggers work; every other
        event is acknowledged without training. Training uses the task text
        and the first selected choice of the first ``choices`` result in the
        annotation.

        Args:
            event_data: Event name, compared against ``'ANNOTATION_CREATED'``.
            data: Event payload; its ``'annotation'`` and ``'task'`` entries
                are read. ``None`` is treated as an empty payload.
            **kwargs: Accepted but not used.

        Returns:
            A dict with ``'status'`` (``'ok'`` or ``'error'``) and a
            human-readable ``'message'``. Exceptions are caught and reported
            through this dict rather than raised.
        """
        logger.info("=== FIT METHOD CALLED ===")

        try:
            if event_data != 'ANNOTATION_CREATED':
                return {'status': 'ok', 'message': 'Event processed'}

            # Query the count once so the check and the messages agree.
            annotation_count = self._get_annotation_count()
            if annotation_count < self.start_training_threshold:
                logger.info(f"Waiting for more annotations. Current: {annotation_count}, Need: {self.start_training_threshold}")
                return {'status': 'ok', 'message': f'Waiting for more annotations ({annotation_count}/{self.start_training_threshold})'}

            # `data` defaults to None; treat that like an empty payload so the
            # caller gets the explicit "missing data" error below.
            payload = data or {}
            annotation = payload.get('annotation', {})
            task = payload.get('task', {})

            if not task or not annotation:
                logger.error("Missing task or annotation data")
                return {'status': 'error', 'message': 'Missing task or annotation data'}

            text = task.get('data', {}).get('text', '')

            for result in annotation.get('result', []):
                if result.get('type') != 'choices':
                    continue

                selected = result.get('value', {}).get('choices', [])
                if not selected:
                    # Report the real problem instead of an IndexError text.
                    logger.error("Annotation result has no selected choice")
                    return {'status': 'error', 'message': 'Annotation result has no selected choice'}

                label = selected[0]
                logger.info(f"Training on - Text: {text[:50]}... Label: {label}")
                # Training happens on the first choices result only.
                return self._train_on_example(text, label)

        except Exception as e:
            logger.error(f"Error in fit method: {str(e)}")
            logger.error("Full error details:", exc_info=True)
            return {'status': 'error', 'message': str(e)}

        return {'status': 'ok', 'message': 'Event processed'}

    def _train_on_example(self, text, label):
        """Train on one (text, label) pair for the configured epochs and save the weights.

        Args:
            text: Text of the annotated task.
            label: Chosen category; must be an element of ``self.categories``.

        Returns:
            A status dict: ``'ok'`` with the average loss on success,
            ``'error'`` with the exception text on any failure.
        """
        try:
            logger.info("Creating dataset...")
            dataset = TextDataset(
                texts=[text],
                labels=[self.categories.index(label)],
                tokenizer=self.tokenizer
            )
            train_loader = DataLoader(dataset, batch_size=1)
            logger.info("βœ“ Dataset created")

            # A fresh optimizer is created for every training call.
            optimizer = AdamW(
                self._model.parameters(),
                lr=self.learning_rate,
                weight_decay=self.weight_decay
            )
            self._model.train()
            logger.info("Starting training...")

            total_loss = 0
            for epoch in range(self.num_train_epochs):
                logger.info(f"Starting epoch {epoch + 1}/{self.num_train_epochs}")
                epoch_loss = 0

                for batch in train_loader:
                    optimizer.zero_grad()

                    # Move batch to the model's device
                    input_ids = batch['input_ids'].to(self.device)
                    attention_mask = batch['attention_mask'].to(self.device)
                    labels = batch['labels'].to(self.device)

                    # Passing labels makes the model return the loss directly.
                    outputs = self._model(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        labels=labels
                    )

                    loss = outputs.loss
                    epoch_loss += loss.item()

                    loss.backward()
                    optimizer.step()

                avg_epoch_loss = epoch_loss / len(train_loader)
                total_loss += avg_epoch_loss
                logger.info(f"Epoch {epoch + 1} loss: {avg_epoch_loss:.4f}")

            avg_training_loss = total_loss / self.num_train_epochs
            logger.info(f"Average training loss: {avg_training_loss:.4f}")

            # Persist the weights so initialize() can restore them later.
            model_path = os.path.join(self.model_dir, 'model_state.pt')
            torch.save(self._model.state_dict(), model_path)
            logger.info(f"βœ“ Model saved to {model_path}")

            return {
                'status': 'ok',
                'message': f'Training completed with avg loss: {avg_training_loss:.4f}'
            }

        except Exception as e:
            logger.error(f"Training error: {str(e)}")
            return {'status': 'error', 'message': str(e)}

    def _get_annotation_count(self):
        """Return the number of annotations available for training.

        Placeholder: real counting is not implemented, so a constant 1 is
        reported. With the default training threshold of 1 this lets training
        start immediately.
        """
        placeholder_count = 1
        return placeholder_count