# ls_be_T5_base/model.py
# Label Studio ML backend: T5-based text classifier with LoRA fine-tuning.
# (Last upstream change: "save full state dict for the model", commit f0d89f2.)
import os
import logging
from typing import List, Dict, Optional
from pathlib import Path
import json
from datetime import datetime
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from label_studio_ml.model import LabelStudioMLBase, ModelResponse
from peft import get_peft_model, LoraConfig, PeftModel
import time
logger = logging.getLogger(__name__)
class T5Model(LabelStudioMLBase):
    """Label Studio ML backend that classifies text with a T5 seq2seq model."""
    # Class-level configuration, overridable via environment variables
    model_name: str = os.getenv('MODEL_NAME', 'google/flan-t5-base')
    # Tokenizer truncation limit for the classification prompt
    max_length: int = int(os.getenv('MAX_LENGTH', '512'))
    # Cap on the length of the generated category text
    generation_max_length: int = int(os.getenv('GENERATION_MAX_LENGTH', '128'))
    num_return_sequences: int = int(os.getenv('NUM_RETURN_SEQUENCES', '1'))
    # Model components (initialized as None)
    tokenizer = None  # transformers tokenizer, populated in setup()
    model = None      # seq2seq model, optionally wrapped with a PEFT adapter in setup()
    device = None  # Will be set during setup ("cuda" or "cpu")
    def setup(self):
        """Initialize the T5 model and parse configuration.

        Loads the base tokenizer/model, moves the model to GPU when
        available, then tries to layer the most recently saved fine-tuned
        (LoRA) adapter on top of it, falling back to the base model if
        that load fails.  Any other setup error is logged and re-raised.
        """
        try:
            # Parse label config first so the prediction tag names are available
            text_config, choices_config = self.parse_config(self.label_config)
            self.from_name = choices_config.get('name')  # Choices tag: where predictions are written
            self.to_name = text_config.get('name')       # Text tag: where input text is read from
            # Load tokenizer and model
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
            # Set device after model loading; move to GPU before any adapter is attached
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            if self.device == "cuda":
                self.model = self.model.cuda()
            # After initializing the base model, try to load the latest fine-tuned version
            latest_model_path = self.get_latest_model_path()
            if latest_model_path is not None:
                try:
                    logger.info(f"Loading latest model from {latest_model_path}")
                    self.model = PeftModel.from_pretrained(self.model, latest_model_path)
                    logger.info("Successfully loaded latest model")
                except Exception as e:
                    logger.error(f"Failed to load latest model: {str(e)}")
                    # Continue with base model if loading fails
            self.model.eval()  # inference mode (disables dropout, etc.)
            logger.info(f"Using device: {self.device}")
            logger.info(f"Initialized with from_name={self.from_name}, to_name={self.to_name}")
            # Set initial model version
            self.set("model_version", "1.0.0")
        except Exception as e:
            logger.error(f"Error in model setup: {str(e)}")
            raise
def parse_config(self, label_config):
"""Parse the label config to find nested elements"""
import xml.etree.ElementTree as ET
root = ET.fromstring(label_config)
# Find Text and Choices tags anywhere in the tree
text_tag = root.find('.//Text')
choices_tag = root.find('.//Choices')
text_config = text_tag.attrib if text_tag is not None else {}
choices_config = choices_tag.attrib if choices_tag is not None else {}
return text_config, choices_config
def get_valid_choices(self, label_config):
"""Extract valid choice values from label config"""
import xml.etree.ElementTree as ET
root = ET.fromstring(label_config)
choices = root.findall('.//Choice')
return [choice.get('value') for choice in choices]
def get_categories_with_hints(self, label_config):
"""Extract categories and their hints from label config"""
import xml.etree.ElementTree as ET
root = ET.fromstring(label_config)
choices = root.findall('.//Choice')
categories = []
for choice in choices:
categories.append({
'value': choice.get('value'),
'hint': choice.get('hint')
})
return categories
def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> ModelResponse:
"""Generate predictions using T5 model"""
logger.info("Received prediction request")
logger.info(f"Tasks: {json.dumps(tasks, indent=2)}")
predictions = []
# Get categories with their descriptions
try:
categories = self.get_categories_with_hints(self.label_config)
valid_choices = [cat['value'] for cat in categories]
category_descriptions = [f"{cat['value']}: {cat['hint']}" for cat in categories]
logger.info(f"Valid choices: {valid_choices}")
except Exception as e:
logger.error(f"Error parsing choices: {str(e)}")
# TODO: remove this from all places once we have a valid choices
valid_choices = ["other"]
category_descriptions = ["other: Default category when no others apply"]
try:
for task in tasks:
input_text = task['data'].get(self.to_name)
if not input_text:
logger.warning(f"No input text found using {self.to_name}")
continue
# Format prompt with input text and category descriptions
prompt = f"""Classify the following text into exactly one category.
Available categories with descriptions:
{chr(10).join(f"- {desc}" for desc in category_descriptions)}
Text to classify: {input_text}
Instructions:
1. Consider the text carefully
2. Choose the most appropriate category from the list
3. Return ONLY the category value (e.g. 'business_and_career', 'date', etc.)
4. Do not add any explanations or additional text
Category:"""
logger.info(f"Generated prompt: {prompt}")
# Generate prediction with prompt
inputs = self.tokenizer(
prompt,
return_tensors="pt",
max_length=self.max_length,
truncation=True,
padding=True
).to(self.device)
logger.info("Generating prediction...")
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_length=self.generation_max_length,
num_return_sequences=self.num_return_sequences,
do_sample=True,
temperature=0.7
)
predicted_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
logger.info(f"Generated prediction: {predicted_text}")
# Find best matching choice
best_choice = "other" # default fallback
if predicted_text in valid_choices:
best_choice = predicted_text
# Format prediction with valid choice
prediction = {
"result": [{
"from_name": self.from_name,
"to_name": self.to_name,
"type": "choices",
"value": {
"choices": [best_choice]
}
}],
"model_version": "1.0.0"
}
logger.info(f"Formatted prediction: {json.dumps(prediction, indent=2)}")
predictions.append(prediction)
except Exception as e:
logger.error(f"Error in prediction: {str(e)}", exc_info=True)
raise
logger.info(f"Returning {len(predictions)} predictions")
return predictions
def fit(self, event, data, **kwargs):
"""Handle annotation events from Label Studio"""
start_time = time.time()
logger.info("Starting training session...")
valid_events = {'ANNOTATION_CREATED', 'ANNOTATION_UPDATED', 'START_TRAINING'}
if event not in valid_events:
logger.warning(f"Skip training: event {event} is not supported")
return
try:
# Extract text and label
# LS sends two webhooks when training is initiated:
# 1. contains all project data
# 2. contains only the task data
# We need to check which one is present and use the appropriate data
if 'task' in data:
text = data['task']['data']['text']
label = data['annotation']['result'][0]['value']['choices'][0]
else:
logger.info("Skipping initial project setup webhook")
return
# Configure LoRA
lora_config = LoraConfig(
r=int(os.getenv('LORA_R', '4')),
lora_alpha=int(os.getenv('LORA_ALPHA', '8')),
target_modules=os.getenv('LORA_TARGET_MODULES', 'q,v').split(','),
lora_dropout=float(os.getenv('LORA_DROPOUT', '0.1')),
bias="none",
task_type="SEQ_2_SEQ_LM"
)
logger.info("Preparing model for training...")
model = get_peft_model(self.model, lora_config)
model.print_trainable_parameters()
# Tokenize inputs first
inputs = self.tokenizer(text, return_tensors="pt", max_length=self.max_length, truncation=True).to(self.device)
labels = self.tokenizer(label, return_tensors="pt", max_length=self.generation_max_length, truncation=True).to(self.device)
# Training loop
logger.info("Starting training loop...")
optimizer = torch.optim.AdamW(model.parameters(), lr=float(os.getenv('LEARNING_RATE', '1e-5')))
num_epochs = int(os.getenv('NUM_EPOCHS', '6'))
# Add LoRA settings logging here
logger.info("Current LoRA Configuration:")
logger.info(f" - Rank (r): {lora_config.r}")
logger.info(f" - Alpha: {lora_config.lora_alpha}")
logger.info(f" - Target Modules: {lora_config.target_modules}")
logger.info(f" - Dropout: {lora_config.lora_dropout}")
logger.info(f" - Learning Rate: {float(os.getenv('LEARNING_RATE', '1e-4'))}")
logger.info(f" - Number of Epochs: {num_epochs}")
logger.info(f" - Input text length: {len(inputs['input_ids'][0])} tokens")
logger.info(f" - Label length: {len(labels['input_ids'][0])} tokens")
for epoch in range(num_epochs):
logger.info(f"Starting epoch {epoch+1}/{num_epochs}")
model.train()
optimizer.zero_grad()
outputs = model(**inputs, labels=labels["input_ids"])
loss = outputs.loss
loss.backward()
optimizer.step()
logger.info(f"Epoch {epoch+1}/{num_epochs} completed. Loss: {loss.item():.4f}")
# Save the model
try:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_dir = Path(os.getenv('MODEL_DIR', '/data/models'))
model_dir.mkdir(parents=True, exist_ok=True)
save_path = model_dir / f"model_{timestamp}"
logger.info(f"Saving model to {save_path}")
# Save the full model state
model.save_pretrained(
save_path,
save_function=torch.save,
safe_serialization=True,
save_state_dict=True
)
logger.info(f"Model successfully saved to {save_path}")
except Exception as e:
logger.error(f"Failed to save model: {str(e)}")
raise
# Save the tokenizer
try:
logger.info(f"Saving tokenizer to {save_path}")
self.tokenizer.save_pretrained(save_path)
logger.info("Tokenizer successfully saved")
except Exception as e:
logger.error(f"Failed to save tokenizer: {str(e)}")
raise
# Switch to eval mode
model.eval()
training_time = time.time() - start_time
logger.info(f"Training session completed successfully in {training_time:.2f} seconds with tag: '{text}' and label: '{label}'")
except Exception as e:
training_time = time.time() - start_time
logger.error(f"Training failed after {training_time:.2f} seconds")
logger.error(f"Error during training: {str(e)}")
raise
def get_latest_model_path(self) -> Path:
"""Get the path to the most recently saved model"""
model_dir = Path(os.getenv('MODEL_DIR', '/data/models'))
if not model_dir.exists():
logger.warning(f"Model directory {model_dir} does not exist")
return None
# Find all model directories (they start with 'model_')
model_paths = list(model_dir.glob("model_*"))
if not model_paths:
logger.warning("No saved models found")
return None
# Sort by creation time and get the most recent
latest_model = max(model_paths, key=lambda x: x.stat().st_mtime)
logger.info(f"Found latest model: {latest_model}")
return latest_model