File size: 14,040 Bytes
7e62685
 
5005caf
7e62685
30af0de
ed47222
7e62685
 
 
20eab11
3a11d9a
d44186e
5005caf
7e62685
 
 
571a6ec
 
 
 
 
 
 
 
 
a9b8d74
571a6ec
 
 
 
 
 
 
 
 
 
 
 
a9b8d74
 
 
 
 
 
3a11d9a
 
 
 
 
 
 
 
 
 
 
571a6ec
 
a9b8d74
 
 
571a6ec
 
 
 
 
 
5005caf
9f7842c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43f685f
 
 
 
 
 
 
 
4b117ce
 
 
 
 
 
 
 
 
 
 
 
 
 
5005caf
7e62685
30af0de
 
 
7e62685
4b117ce
51dd92c
4b117ce
 
 
51dd92c
 
 
724c5a8
 
 
5005caf
7e62685
 
4841fbd
7e62685
30af0de
7e62685
 
4b117ce
 
 
520dac2
 
4b117ce
520dac2
4b117ce
520dac2
 
 
 
 
4b117ce
520dac2
4b117ce
 
 
 
7e62685
4b117ce
7e62685
 
 
 
 
 
30af0de
7e62685
 
 
 
 
 
 
 
30af0de
7e62685
30af0de
7e62685
43f685f
724c5a8
43f685f
 
 
 
6e8b334
7e62685
 
 
4841fbd
7e62685
43f685f
7e62685
30af0de
603d158
6e8b334
30af0de
7e62685
30af0de
7e62685
30af0de
7e62685
 
30af0de
603d158
1b10a67
5005caf
d44186e
 
 
e8ef760
 
 
 
 
0bd8dce
d44186e
 
34fc450
 
 
 
 
 
 
 
 
 
d44186e
 
 
5b6ee0c
 
 
 
d44186e
 
 
 
 
 
 
 
92ec7e6
 
 
 
d44186e
 
5b6ee0c
d44186e
5b6ee0c
24a494f
334eca1
6c24fe0
 
 
 
 
 
 
 
 
 
 
4a32b85
 
 
 
 
 
 
 
 
 
 
 
d44186e
ed47222
 
 
 
 
 
 
 
f0d89f2
 
 
 
 
 
 
 
ed47222
 
 
 
 
 
72c90e7
 
 
 
 
 
 
 
 
3a11d9a
d44186e
 
 
 
 
 
 
 
 
 
5005caf
3a11d9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
import json
import logging
import os
import time
import xml.etree.ElementTree as ET
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Optional

import torch
from label_studio_ml.model import LabelStudioMLBase, ModelResponse
from peft import get_peft_model, LoraConfig, PeftModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

logger = logging.getLogger(__name__)

class T5Model(LabelStudioMLBase):
    """Label Studio ML backend for single-label text classification with T5.

    Prompts a seq2seq T5 model to pick exactly one category from the
    labeling config's ``<Choices>`` tags, and supports incremental LoRA
    fine-tuning driven by Label Studio annotation webhooks. Fine-tuned
    adapters are persisted under MODEL_DIR and the newest one is layered
    onto the base model at startup.
    """

    # Class-level configuration, read from the environment at import time
    model_name = os.getenv('MODEL_NAME', 'google/flan-t5-base')
    max_length = int(os.getenv('MAX_LENGTH', '512'))
    generation_max_length = int(os.getenv('GENERATION_MAX_LENGTH', '128'))
    num_return_sequences = int(os.getenv('NUM_RETURN_SEQUENCES', '1'))

    # Model components (populated in setup())
    tokenizer = None
    model = None
    device = None  # "cuda" or "cpu"; set during setup

    def setup(self):
        """Initialize the tokenizer/model and parse the labeling config.

        Loads the base model, moves it to GPU when available, then tries to
        layer the most recent LoRA checkpoint on top, falling back to the
        base model if that fails.

        Raises:
            Exception: re-raised after logging if any part of setup fails.
        """
        try:
            # Parse the label config first so prediction results can
            # reference the correct control/object tag names.
            text_config, choices_config = self.parse_config(self.label_config)
            self.from_name = choices_config.get('name')
            self.to_name = text_config.get('name')

            # Load tokenizer and model
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)

            # Set device after model loading
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            if self.device == "cuda":
                self.model = self.model.cuda()

            # After initializing the base model, try to load the latest
            # fine-tuned (LoRA) adapter on top of it.
            latest_model_path = self.get_latest_model_path()
            if latest_model_path is not None:
                try:
                    logger.info(f"Loading latest model from {latest_model_path}")
                    self.model = PeftModel.from_pretrained(self.model, latest_model_path)
                    logger.info("Successfully loaded latest model")
                except Exception as e:
                    logger.error(f"Failed to load latest model: {str(e)}")
                    # Continue with base model if loading fails

            self.model.eval()

            logger.info(f"Using device: {self.device}")
            logger.info(f"Initialized with from_name={self.from_name}, to_name={self.to_name}")

            # Set initial model version
            self.set("model_version", "1.0.0")

        except Exception as e:
            logger.error(f"Error in model setup: {str(e)}")
            raise

    def parse_config(self, label_config):
        """Parse the labeling config XML for the Text and Choices tags.

        Args:
            label_config: Label Studio labeling config XML string.

        Returns:
            Tuple ``(text_config, choices_config)`` of attribute dicts for
            the first ``<Text>`` and ``<Choices>`` tags found anywhere in
            the tree; either dict is empty when the tag is absent.
        """
        root = ET.fromstring(label_config)

        # Find Text and Choices tags anywhere in the tree
        text_tag = root.find('.//Text')
        choices_tag = root.find('.//Choices')

        text_config = text_tag.attrib if text_tag is not None else {}
        choices_config = choices_tag.attrib if choices_tag is not None else {}

        return text_config, choices_config

    def get_valid_choices(self, label_config):
        """Return the list of ``<Choice value="...">`` values from the config."""
        root = ET.fromstring(label_config)
        choices = root.findall('.//Choice')
        return [choice.get('value') for choice in choices]

    def get_categories_with_hints(self, label_config):
        """Extract categories and their hints from the labeling config.

        Returns:
            List of ``{'value': ..., 'hint': ...}`` dicts, one per
            ``<Choice>`` tag; ``hint`` is None when the attribute is absent.
        """
        root = ET.fromstring(label_config)
        return [
            {'value': choice.get('value'), 'hint': choice.get('hint')}
            for choice in root.findall('.//Choice')
        ]

    def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> ModelResponse:
        """Generate one classification prediction per task using the T5 model.

        Args:
            tasks: Label Studio task dicts; input text is read from
                ``task['data'][self.to_name]``.
            context: optional prediction context (unused).

        Returns:
            ModelResponse wrapping the list of prediction dicts, matching
            the declared return type (previously a bare list was returned).
        """
        logger.info("Received prediction request")
        logger.info(f"Tasks: {json.dumps(tasks, indent=2)}")

        predictions = []
        # Get categories with their descriptions
        try:
            categories = self.get_categories_with_hints(self.label_config)
            valid_choices = [cat['value'] for cat in categories]
            category_descriptions = [f"{cat['value']}: {cat['hint']}" for cat in categories]
            logger.info(f"Valid choices: {valid_choices}")
        except Exception as e:
            logger.error(f"Error parsing choices: {str(e)}")
            # TODO: remove this from all places once we have a valid choices
            valid_choices = ["other"]
            category_descriptions = ["other: Default category when no others apply"]

        try:
            for task in tasks:
                input_text = task['data'].get(self.to_name)
                if not input_text:
                    logger.warning(f"No input text found using {self.to_name}")
                    continue

                # Format prompt with input text and category descriptions
                prompt = f"""Classify the following text into exactly one category.

                            Available categories with descriptions:
                            {chr(10).join(f"- {desc}" for desc in category_descriptions)}

                            Text to classify: {input_text}

                            Instructions:
                            1. Consider the text carefully
                            2. Choose the most appropriate category from the list
                            3. Return ONLY the category value (e.g. 'business_and_career', 'date', etc.)
                            4. Do not add any explanations or additional text

                            Category:"""

                logger.info(f"Generated prompt: {prompt}")

                # Generate prediction with prompt
                inputs = self.tokenizer(
                    prompt,
                    return_tensors="pt",
                    max_length=self.max_length,
                    truncation=True,
                    padding=True
                ).to(self.device)

                logger.info("Generating prediction...")
                with torch.no_grad():
                    outputs = self.model.generate(
                        **inputs,
                        max_length=self.generation_max_length,
                        num_return_sequences=self.num_return_sequences,
                        do_sample=True,
                        temperature=0.7
                    )

                # Strip whitespace so a correct answer with stray padding
                # spaces/newlines still matches a valid choice.
                predicted_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
                logger.info(f"Generated prediction: {predicted_text}")

                # Find best matching choice
                best_choice = "other"  # default fallback
                if predicted_text in valid_choices:
                    best_choice = predicted_text

                # Format prediction with valid choice
                prediction = {
                    "result": [{
                        "from_name": self.from_name,
                        "to_name": self.to_name,
                        "type": "choices",
                        "value": {
                            "choices": [best_choice]
                        }
                    }],
                    "model_version": "1.0.0"
                }
                logger.info(f"Formatted prediction: {json.dumps(prediction, indent=2)}")
                predictions.append(prediction)

        except Exception as e:
            logger.error(f"Error in prediction: {str(e)}", exc_info=True)
            raise

        logger.info(f"Returning {len(predictions)} predictions")
        # Wrap in ModelResponse to honor the declared return type.
        return ModelResponse(predictions=predictions)

    def fit(self, event, data, **kwargs):
        """Fine-tune the model with LoRA on a single annotated example.

        Triggered by Label Studio webhooks; only ANNOTATION_CREATED,
        ANNOTATION_UPDATED and START_TRAINING events are handled. The
        trained adapter and tokenizer are saved to a timestamped directory
        under MODEL_DIR.

        Args:
            event: webhook event name.
            data: webhook payload; a training payload carries 'task' and
                'annotation' keys.

        Raises:
            Exception: re-raised after logging if training or saving fails.
        """
        start_time = time.time()
        logger.info("Starting training session...")

        valid_events = {'ANNOTATION_CREATED', 'ANNOTATION_UPDATED', 'START_TRAINING'}
        if event not in valid_events:
            logger.warning(f"Skip training: event {event} is not supported")
            return

        try:
            # Extract text and label
            # LS sends two webhooks when training is initiated:
            # 1. contains all project data
            # 2. contains only the task data
            # We need to check which one is present and use the appropriate data
            if 'task' in data:
                text = data['task']['data']['text']
                label = data['annotation']['result'][0]['value']['choices'][0]
            else:
                logger.info("Skipping initial project setup webhook")
                return

            # Configure LoRA
            lora_config = LoraConfig(
                r=int(os.getenv('LORA_R', '4')),
                lora_alpha=int(os.getenv('LORA_ALPHA', '8')),
                target_modules=os.getenv('LORA_TARGET_MODULES', 'q,v').split(','),
                lora_dropout=float(os.getenv('LORA_DROPOUT', '0.1')),
                bias="none",
                task_type="SEQ_2_SEQ_LM"
            )

            logger.info("Preparing model for training...")
            # NOTE(review): if setup() already wrapped self.model in a
            # PeftModel, this wraps it again — confirm intended stacking.
            model = get_peft_model(self.model, lora_config)
            model.print_trainable_parameters()

            # Tokenize inputs first
            inputs = self.tokenizer(text, return_tensors="pt", max_length=self.max_length, truncation=True).to(self.device)
            labels = self.tokenizer(label, return_tensors="pt", max_length=self.generation_max_length, truncation=True).to(self.device)

            # Training loop
            logger.info("Starting training loop...")
            # Read the learning rate once so the optimizer and the log below
            # agree (previously the log used a different default, 1e-4).
            learning_rate = float(os.getenv('LEARNING_RATE', '1e-5'))
            optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

            num_epochs = int(os.getenv('NUM_EPOCHS', '6'))

            # Add LoRA settings logging here
            logger.info("Current LoRA Configuration:")
            logger.info(f"  - Rank (r): {lora_config.r}")
            logger.info(f"  - Alpha: {lora_config.lora_alpha}")
            logger.info(f"  - Target Modules: {lora_config.target_modules}")
            logger.info(f"  - Dropout: {lora_config.lora_dropout}")
            logger.info(f"  - Learning Rate: {learning_rate}")
            logger.info(f"  - Number of Epochs: {num_epochs}")
            logger.info(f"  - Input text length: {len(inputs['input_ids'][0])} tokens")
            logger.info(f"  - Label length: {len(labels['input_ids'][0])} tokens")

            for epoch in range(num_epochs):
                logger.info(f"Starting epoch {epoch+1}/{num_epochs}")

                model.train()
                optimizer.zero_grad()

                outputs = model(**inputs, labels=labels["input_ids"])
                loss = outputs.loss
                loss.backward()
                optimizer.step()

                logger.info(f"Epoch {epoch+1}/{num_epochs} completed. Loss: {loss.item():.4f}")

            # Save the model
            try:
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                model_dir = Path(os.getenv('MODEL_DIR', '/data/models'))
                model_dir.mkdir(parents=True, exist_ok=True)

                save_path = model_dir / f"model_{timestamp}"
                logger.info(f"Saving model to {save_path}")

                # Save the full model state
                model.save_pretrained(
                    save_path,
                    save_function=torch.save,
                    safe_serialization=True,
                    save_state_dict=True
                )
                logger.info(f"Model successfully saved to {save_path}")

            except Exception as e:
                logger.error(f"Failed to save model: {str(e)}")
                raise

            # Save the tokenizer
            try:
                logger.info(f"Saving tokenizer to {save_path}")
                self.tokenizer.save_pretrained(save_path)
                logger.info("Tokenizer successfully saved")
            except Exception as e:
                logger.error(f"Failed to save tokenizer: {str(e)}")
                raise

            # Switch to eval mode
            model.eval()

            training_time = time.time() - start_time
            logger.info(f"Training session completed successfully in {training_time:.2f} seconds with tag: '{text}' and label: '{label}'")

        except Exception as e:
            training_time = time.time() - start_time
            logger.error(f"Training failed after {training_time:.2f} seconds")
            logger.error(f"Error during training: {str(e)}")
            raise

    def get_latest_model_path(self) -> Optional[Path]:
        """Get the path to the most recently saved model.

        Returns:
            Path of the newest ``model_*`` directory under MODEL_DIR (by
            mtime), or None when the directory or any saved model is
            missing. (Annotation fixed: the method legitimately returns
            None on the empty cases.)
        """
        model_dir = Path(os.getenv('MODEL_DIR', '/data/models'))
        if not model_dir.exists():
            logger.warning(f"Model directory {model_dir} does not exist")
            return None

        # Find all model directories (they start with 'model_')
        model_paths = list(model_dir.glob("model_*"))
        if not model_paths:
            logger.warning("No saved models found")
            return None

        # Sort by creation time and get the most recent
        latest_model = max(model_paths, key=lambda x: x.stat().st_mtime)
        logger.info(f"Found latest model: {latest_model}")
        return latest_model