import json
import logging
import os
import time
import xml.etree.ElementTree as ET
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from label_studio_ml.model import LabelStudioMLBase, ModelResponse
from peft import get_peft_model, LoraConfig, PeftModel

logger = logging.getLogger(__name__)


class T5Model(LabelStudioMLBase):
    """Label Studio ML backend for text classification with a T5 model.

    Generates single-label predictions by prompting a seq2seq model with the
    task text plus the category hints parsed from the labeling config, and
    fine-tunes a LoRA adapter on each submitted annotation.
    """

    # Class-level configuration, overridable via environment variables.
    model_name = os.getenv('MODEL_NAME', 'google/flan-t5-base')
    max_length = int(os.getenv('MAX_LENGTH', '512'))
    generation_max_length = int(os.getenv('GENERATION_MAX_LENGTH', '128'))
    num_return_sequences = int(os.getenv('NUM_RETURN_SEQUENCES', '1'))

    # Model components (populated by setup()).
    tokenizer = None
    model = None
    device = None  # set during setup

    def setup(self):
        """Initialize the T5 model and parse configuration.

        Loads the base tokenizer/model, moves the model to GPU when available,
        then tries to layer the most recently saved LoRA adapter on top.
        Falls back to the base model if adapter loading fails.

        Raises:
            Exception: re-raised after logging if any initialization step fails.
        """
        try:
            # Parse label config first.
            text_config, choices_config = self.parse_config(self.label_config)
            self.from_name = choices_config.get('name')
            self.to_name = text_config.get('name')
            # The task data key comes from the Text tag's *value* attribute
            # (e.g. value="$text" -> key "text"), not its name. Fall back to
            # the tag name for configs where the two coincide.
            value = text_config.get('value') or ''
            self.value_key = value[1:] if value.startswith('$') else self.to_name

            # Load tokenizer and model.
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)

            # Set device after model loading.
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            if self.device == "cuda":
                self.model = self.model.cuda()

            # After initializing the base model, try to load the latest
            # fine-tuned adapter on top of it.
            latest_model_path = self.get_latest_model_path()
            if latest_model_path is not None:
                try:
                    logger.info(f"Loading latest model from {latest_model_path}")
                    self.model = PeftModel.from_pretrained(self.model, latest_model_path)
                    logger.info("Successfully loaded latest model")
                except Exception as e:
                    logger.error(f"Failed to load latest model: {str(e)}")
                    # Continue with base model if loading fails.

            self.model.eval()
            logger.info(f"Using device: {self.device}")
            logger.info(f"Initialized with from_name={self.from_name}, to_name={self.to_name}")
            # Set initial model version.
            self.set("model_version", "1.0.0")
        except Exception as e:
            logger.error(f"Error in model setup: {str(e)}")
            raise

    def parse_config(self, label_config):
        """Parse the label config to find the nested Text and Choices tags.

        Args:
            label_config: XML labeling config string.

        Returns:
            Tuple of (text_attrib_dict, choices_attrib_dict); either may be
            empty if the corresponding tag is absent.
        """
        root = ET.fromstring(label_config)
        # Find Text and Choices tags anywhere in the tree.
        text_tag = root.find('.//Text')
        choices_tag = root.find('.//Choices')
        text_config = text_tag.attrib if text_tag is not None else {}
        choices_config = choices_tag.attrib if choices_tag is not None else {}
        return text_config, choices_config

    def get_valid_choices(self, label_config):
        """Extract valid choice values from the label config."""
        root = ET.fromstring(label_config)
        return [choice.get('value') for choice in root.findall('.//Choice')]

    def get_categories_with_hints(self, label_config):
        """Extract categories and their hints from the label config.

        Returns:
            List of dicts with 'value' and 'hint' keys, one per Choice tag
            (hint is None when the attribute is missing).
        """
        root = ET.fromstring(label_config)
        return [
            {'value': choice.get('value'), 'hint': choice.get('hint')}
            for choice in root.findall('.//Choice')
        ]

    def _build_prompt(self, input_text, category_descriptions):
        """Format the classification prompt for one task's text."""
        joined = chr(10).join(f"- {desc}" for desc in category_descriptions)
        return f"""Classify the following text into exactly one category.

Available categories with descriptions:
{joined}

Text to classify: {input_text}

Instructions:
1. Consider the text carefully
2. Choose the most appropriate category from the list
3. Return ONLY the category value (e.g. 'business_and_career', 'date', etc.)
4. Do not add any explanations or additional text

Category:"""

    def _generate(self, prompt):
        """Run the model on a prompt and return the decoded first sequence."""
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            max_length=self.max_length,
            truncation=True,
            padding=True
        ).to(self.device)

        logger.info("Generating prediction...")
        with torch.no_grad():
            # NOTE(review): do_sample=True makes classification output
            # non-deterministic across calls — confirm this is intended.
            outputs = self.model.generate(
                **inputs,
                max_length=self.generation_max_length,
                num_return_sequences=self.num_return_sequences,
                do_sample=True,
                temperature=0.7
            )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> ModelResponse:
        """Generate predictions using the T5 model.

        Args:
            tasks: Label Studio task dicts; the text is read from each task's
                data using the key derived from the Text tag.
            context: optional interactive-labeling context (unused).

        Returns:
            ModelResponse wrapping one choices-prediction per task with text.

        Raises:
            Exception: re-raised after logging if generation fails.
        """
        logger.info("Received prediction request")
        logger.info(f"Tasks: {json.dumps(tasks, indent=2)}")
        predictions = []

        # Get categories with their descriptions.
        try:
            categories = self.get_categories_with_hints(self.label_config)
            valid_choices = [cat['value'] for cat in categories]
            category_descriptions = [f"{cat['value']}: {cat['hint']}" for cat in categories]
            logger.info(f"Valid choices: {valid_choices}")
        except Exception as e:
            logger.error(f"Error parsing choices: {str(e)}")
            # TODO: remove this from all places once we have a valid choices
            valid_choices = ["other"]
            category_descriptions = ["other: Default category when no others apply"]

        # Report the stored version rather than a hardcoded constant.
        model_version = self.get("model_version") or "1.0.0"
        data_key = getattr(self, 'value_key', None) or self.to_name

        try:
            for task in tasks:
                input_text = task['data'].get(data_key)
                if not input_text:
                    logger.warning(f"No input text found using {data_key}")
                    continue

                prompt = self._build_prompt(input_text, category_descriptions)
                logger.info(f"Generated prompt: {prompt}")

                predicted_text = self._generate(prompt)
                logger.info(f"Generated prediction: {predicted_text}")

                # Map the raw generation onto a valid choice, defaulting to
                # "other" when the model output matches nothing.
                best_choice = predicted_text if predicted_text in valid_choices else "other"

                prediction = {
                    "result": [{
                        "from_name": self.from_name,
                        "to_name": self.to_name,
                        "type": "choices",
                        "value": {
                            "choices": [best_choice]
                        }
                    }],
                    "model_version": model_version
                }
                logger.info(f"Formatted prediction: {json.dumps(prediction, indent=2)}")
                predictions.append(prediction)
        except Exception as e:
            logger.error(f"Error in prediction: {str(e)}", exc_info=True)
            raise

        logger.info(f"Returning {len(predictions)} predictions")
        # Wrap in ModelResponse to match the declared return type.
        return ModelResponse(predictions=predictions)

    def fit(self, event, data, **kwargs):
        """Handle annotation events from Label Studio by LoRA-fine-tuning.

        Trains a fresh LoRA adapter on the single annotated example carried by
        the webhook payload and saves adapter + tokenizer to a timestamped
        directory under MODEL_DIR.

        Args:
            event: webhook event name; only annotation/training events train.
            data: webhook payload; per-task payloads contain 'task' and
                'annotation' keys.

        Raises:
            Exception: re-raised after logging if training or saving fails.
        """
        start_time = time.time()
        logger.info("Starting training session...")

        valid_events = {'ANNOTATION_CREATED', 'ANNOTATION_UPDATED', 'START_TRAINING'}
        if event not in valid_events:
            logger.warning(f"Skip training: event {event} is not supported")
            return

        try:
            # Extract text and label.
            # LS sends two webhooks when training is initiated:
            # 1. contains all project data
            # 2. contains only the task data
            # We need to check which one is present and use the appropriate data.
            if 'task' in data:
                # NOTE(review): the data key is hardcoded to 'text' here while
                # predict() derives it from the label config — confirm they agree.
                text = data['task']['data']['text']
                label = data['annotation']['result'][0]['value']['choices'][0]
            else:
                logger.info("Skipping initial project setup webhook")
                return

            # Configure LoRA.
            lora_config = LoraConfig(
                r=int(os.getenv('LORA_R', '4')),
                lora_alpha=int(os.getenv('LORA_ALPHA', '8')),
                target_modules=os.getenv('LORA_TARGET_MODULES', 'q,v').split(','),
                lora_dropout=float(os.getenv('LORA_DROPOUT', '0.1')),
                bias="none",
                task_type="SEQ_2_SEQ_LM"
            )

            logger.info("Preparing model for training...")
            # NOTE(review): the trained wrapper is never assigned back to
            # self.model; the adapter only takes effect after the next setup()
            # reloads it from disk — confirm this is intended.
            model = get_peft_model(self.model, lora_config)
            model.print_trainable_parameters()

            # Tokenize inputs first.
            inputs = self.tokenizer(
                text, return_tensors="pt", max_length=self.max_length, truncation=True
            ).to(self.device)
            labels = self.tokenizer(
                label, return_tensors="pt", max_length=self.generation_max_length, truncation=True
            ).to(self.device)

            # Training loop.
            logger.info("Starting training loop...")
            # Read the learning rate once so the optimizer and the log below
            # agree (previously the log used default '1e-4' while the
            # optimizer used '1e-5').
            learning_rate = float(os.getenv('LEARNING_RATE', '1e-5'))
            optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
            num_epochs = int(os.getenv('NUM_EPOCHS', '6'))

            # Log LoRA settings.
            logger.info("Current LoRA Configuration:")
            logger.info(f" - Rank (r): {lora_config.r}")
            logger.info(f" - Alpha: {lora_config.lora_alpha}")
            logger.info(f" - Target Modules: {lora_config.target_modules}")
            logger.info(f" - Dropout: {lora_config.lora_dropout}")
            logger.info(f" - Learning Rate: {learning_rate}")
            logger.info(f" - Number of Epochs: {num_epochs}")
            logger.info(f" - Input text length: {len(inputs['input_ids'][0])} tokens")
            logger.info(f" - Label length: {len(labels['input_ids'][0])} tokens")

            for epoch in range(num_epochs):
                logger.info(f"Starting epoch {epoch+1}/{num_epochs}")
                model.train()
                optimizer.zero_grad()
                outputs = model(**inputs, labels=labels["input_ids"])
                loss = outputs.loss
                loss.backward()
                optimizer.step()
                logger.info(f"Epoch {epoch+1}/{num_epochs} completed. Loss: {loss.item():.4f}")

            # Save the model.
            try:
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                model_dir = Path(os.getenv('MODEL_DIR', '/data/models'))
                model_dir.mkdir(parents=True, exist_ok=True)
                save_path = model_dir / f"model_{timestamp}"
                logger.info(f"Saving model to {save_path}")
                # safe_serialization selects safetensors; the previously passed
                # save_function/save_state_dict kwargs were silently ignored by
                # PEFT's save_pretrained, so they are dropped here.
                model.save_pretrained(save_path, safe_serialization=True)
                logger.info(f"Model successfully saved to {save_path}")
            except Exception as e:
                logger.error(f"Failed to save model: {str(e)}")
                raise

            # Save the tokenizer alongside the adapter.
            try:
                logger.info(f"Saving tokenizer to {save_path}")
                self.tokenizer.save_pretrained(save_path)
                logger.info("Tokenizer successfully saved")
            except Exception as e:
                logger.error(f"Failed to save tokenizer: {str(e)}")
                raise

            # Switch back to eval mode.
            model.eval()

            training_time = time.time() - start_time
            logger.info(
                f"Training session completed successfully in {training_time:.2f} seconds "
                f"with tag: '{text}' and label: '{label}'"
            )
        except Exception as e:
            training_time = time.time() - start_time
            logger.error(f"Training failed after {training_time:.2f} seconds")
            logger.error(f"Error during training: {str(e)}")
            raise

    def get_latest_model_path(self) -> Optional[Path]:
        """Get the path to the most recently saved model.

        Returns:
            Path of the newest 'model_*' directory under MODEL_DIR, or None
            when the directory or any saved model is missing.
        """
        model_dir = Path(os.getenv('MODEL_DIR', '/data/models'))
        if not model_dir.exists():
            logger.warning(f"Model directory {model_dir} does not exist")
            return None

        # Find all model directories (they start with 'model_').
        model_paths = list(model_dir.glob("model_*"))
        if not model_paths:
            logger.warning("No saved models found")
            return None

        # Sort by modification time and take the most recent.
        latest_model = max(model_paths, key=lambda x: x.stat().st_mtime)
        logger.info(f"Found latest model: {latest_model}")
        return latest_model