import os
import logging
from typing import List, Dict, Optional
from pathlib import Path
import json
from datetime import datetime
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from label_studio_ml.model import LabelStudioMLBase, ModelResponse
from peft import get_peft_model, LoraConfig, PeftModel
import time
# Module-level logger named after this module (standard library logging pattern).
logger = logging.getLogger(__name__)
class T5Model(LabelStudioMLBase):
    """Label Studio ML backend that classifies text with a T5 model,
    optionally fine-tuned with LoRA adapters saved under MODEL_DIR.
    """
    # Class-level configuration, overridable via environment variables.
    model_name = os.getenv('MODEL_NAME', 'google/flan-t5-base')  # HF hub model id
    max_length = int(os.getenv('MAX_LENGTH', '512'))  # max input tokens for the tokenizer
    generation_max_length = int(os.getenv('GENERATION_MAX_LENGTH', '128'))  # max tokens to generate
    num_return_sequences = int(os.getenv('NUM_RETURN_SEQUENCES', '1'))  # sequences per generate() call
    # Model components (initialized as None; populated in setup())
    tokenizer = None
    model = None
    device = None  # Will be set during setup ("cuda" or "cpu")
def setup(self):
"""Initialize the T5 model and parse configuration"""
try:
# Parse label config first
text_config, choices_config = self.parse_config(self.label_config)
self.from_name = choices_config.get('name')
self.to_name = text_config.get('name')
# Load tokenizer and model
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
# Set device after model loading
self.device = "cuda" if torch.cuda.is_available() else "cpu"
if self.device == "cuda":
self.model = self.model.cuda()
# After initializing the base model, try to load the latest fine-tuned version
latest_model_path = self.get_latest_model_path()
if latest_model_path is not None:
try:
logger.info(f"Loading latest model from {latest_model_path}")
self.model = PeftModel.from_pretrained(self.model, latest_model_path)
logger.info("Successfully loaded latest model")
except Exception as e:
logger.error(f"Failed to load latest model: {str(e)}")
# Continue with base model if loading fails
self.model.eval()
logger.info(f"Using device: {self.device}")
logger.info(f"Initialized with from_name={self.from_name}, to_name={self.to_name}")
# Set initial model version
self.set("model_version", "1.0.0")
except Exception as e:
logger.error(f"Error in model setup: {str(e)}")
raise
def parse_config(self, label_config):
"""Parse the label config to find nested elements"""
import xml.etree.ElementTree as ET
root = ET.fromstring(label_config)
# Find Text and Choices tags anywhere in the tree
text_tag = root.find('.//Text')
choices_tag = root.find('.//Choices')
text_config = text_tag.attrib if text_tag is not None else {}
choices_config = choices_tag.attrib if choices_tag is not None else {}
return text_config, choices_config
def get_valid_choices(self, label_config):
"""Extract valid choice values from label config"""
import xml.etree.ElementTree as ET
root = ET.fromstring(label_config)
choices = root.findall('.//Choice')
return [choice.get('value') for choice in choices]
def get_categories_with_hints(self, label_config):
"""Extract categories and their hints from label config"""
import xml.etree.ElementTree as ET
root = ET.fromstring(label_config)
choices = root.findall('.//Choice')
categories = []
for choice in choices:
categories.append({
'value': choice.get('value'),
'hint': choice.get('hint')
})
return categories
def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> ModelResponse:
"""Generate predictions using T5 model"""
logger.info("Received prediction request")
logger.info(f"Tasks: {json.dumps(tasks, indent=2)}")
predictions = []
# Get categories with their descriptions
try:
categories = self.get_categories_with_hints(self.label_config)
valid_choices = [cat['value'] for cat in categories]
category_descriptions = [f"{cat['value']}: {cat['hint']}" for cat in categories]
logger.info(f"Valid choices: {valid_choices}")
except Exception as e:
logger.error(f"Error parsing choices: {str(e)}")
# TODO: remove this from all places once we have a valid choices
valid_choices = ["other"]
category_descriptions = ["other: Default category when no others apply"]
try:
for task in tasks:
input_text = task['data'].get(self.to_name)
if not input_text:
logger.warning(f"No input text found using {self.to_name}")
continue
# Format prompt with input text and category descriptions
prompt = f"""Classify the following text into exactly one category.
Available categories with descriptions:
{chr(10).join(f"- {desc}" for desc in category_descriptions)}
Text to classify: {input_text}
Instructions:
1. Consider the text carefully
2. Choose the most appropriate category from the list
3. Return ONLY the category value (e.g. 'business_and_career', 'date', etc.)
4. Do not add any explanations or additional text
Category:"""
logger.info(f"Generated prompt: {prompt}")
# Generate prediction with prompt
inputs = self.tokenizer(
prompt,
return_tensors="pt",
max_length=self.max_length,
truncation=True,
padding=True
).to(self.device)
logger.info("Generating prediction...")
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_length=self.generation_max_length,
num_return_sequences=self.num_return_sequences,
do_sample=True,
temperature=0.7
)
predicted_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
logger.info(f"Generated prediction: {predicted_text}")
# Find best matching choice
best_choice = "other" # default fallback
if predicted_text in valid_choices:
best_choice = predicted_text
# Format prediction with valid choice
prediction = {
"result": [{
"from_name": self.from_name,
"to_name": self.to_name,
"type": "choices",
"value": {
"choices": [best_choice]
}
}],
"model_version": "1.0.0"
}
logger.info(f"Formatted prediction: {json.dumps(prediction, indent=2)}")
predictions.append(prediction)
except Exception as e:
logger.error(f"Error in prediction: {str(e)}", exc_info=True)
raise
logger.info(f"Returning {len(predictions)} predictions")
return predictions
def fit(self, event, data, **kwargs):
"""Handle annotation events from Label Studio"""
start_time = time.time()
logger.info("Starting training session...")
valid_events = {'ANNOTATION_CREATED', 'ANNOTATION_UPDATED', 'START_TRAINING'}
if event not in valid_events:
logger.warning(f"Skip training: event {event} is not supported")
return
try:
# Extract text and label
# LS sends two webhooks when training is initiated:
# 1. contains all project data
# 2. contains only the task data
# We need to check which one is present and use the appropriate data
if 'task' in data:
text = data['task']['data']['text']
label = data['annotation']['result'][0]['value']['choices'][0]
else:
logger.info("Skipping initial project setup webhook")
return
# Configure LoRA
lora_config = LoraConfig(
r=int(os.getenv('LORA_R', '4')),
lora_alpha=int(os.getenv('LORA_ALPHA', '8')),
target_modules=os.getenv('LORA_TARGET_MODULES', 'q,v').split(','),
lora_dropout=float(os.getenv('LORA_DROPOUT', '0.1')),
bias="none",
task_type="SEQ_2_SEQ_LM"
)
logger.info("Preparing model for training...")
model = get_peft_model(self.model, lora_config)
model.print_trainable_parameters()
# Tokenize inputs first
inputs = self.tokenizer(text, return_tensors="pt", max_length=self.max_length, truncation=True).to(self.device)
labels = self.tokenizer(label, return_tensors="pt", max_length=self.generation_max_length, truncation=True).to(self.device)
# Training loop
logger.info("Starting training loop...")
optimizer = torch.optim.AdamW(model.parameters(), lr=float(os.getenv('LEARNING_RATE', '1e-5')))
num_epochs = int(os.getenv('NUM_EPOCHS', '6'))
# Add LoRA settings logging here
logger.info("Current LoRA Configuration:")
logger.info(f" - Rank (r): {lora_config.r}")
logger.info(f" - Alpha: {lora_config.lora_alpha}")
logger.info(f" - Target Modules: {lora_config.target_modules}")
logger.info(f" - Dropout: {lora_config.lora_dropout}")
logger.info(f" - Learning Rate: {float(os.getenv('LEARNING_RATE', '1e-4'))}")
logger.info(f" - Number of Epochs: {num_epochs}")
logger.info(f" - Input text length: {len(inputs['input_ids'][0])} tokens")
logger.info(f" - Label length: {len(labels['input_ids'][0])} tokens")
for epoch in range(num_epochs):
logger.info(f"Starting epoch {epoch+1}/{num_epochs}")
model.train()
optimizer.zero_grad()
outputs = model(**inputs, labels=labels["input_ids"])
loss = outputs.loss
loss.backward()
optimizer.step()
logger.info(f"Epoch {epoch+1}/{num_epochs} completed. Loss: {loss.item():.4f}")
# Save the model
try:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_dir = Path(os.getenv('MODEL_DIR', '/data/models'))
model_dir.mkdir(parents=True, exist_ok=True)
save_path = model_dir / f"model_{timestamp}"
logger.info(f"Saving model to {save_path}")
# Save the full model state
model.save_pretrained(
save_path,
save_function=torch.save,
safe_serialization=True,
save_state_dict=True
)
logger.info(f"Model successfully saved to {save_path}")
except Exception as e:
logger.error(f"Failed to save model: {str(e)}")
raise
# Save the tokenizer
try:
logger.info(f"Saving tokenizer to {save_path}")
self.tokenizer.save_pretrained(save_path)
logger.info("Tokenizer successfully saved")
except Exception as e:
logger.error(f"Failed to save tokenizer: {str(e)}")
raise
# Switch to eval mode
model.eval()
training_time = time.time() - start_time
logger.info(f"Training session completed successfully in {training_time:.2f} seconds with tag: '{text}' and label: '{label}'")
except Exception as e:
training_time = time.time() - start_time
logger.error(f"Training failed after {training_time:.2f} seconds")
logger.error(f"Error during training: {str(e)}")
raise
def get_latest_model_path(self) -> Path:
"""Get the path to the most recently saved model"""
model_dir = Path(os.getenv('MODEL_DIR', '/data/models'))
if not model_dir.exists():
logger.warning(f"Model directory {model_dir} does not exist")
return None
# Find all model directories (they start with 'model_')
model_paths = list(model_dir.glob("model_*"))
if not model_paths:
logger.warning("No saved models found")
return None
# Sort by creation time and get the most recent
latest_model = max(model_paths, key=lambda x: x.stat().st_mtime)
logger.info(f"Found latest model: {latest_model}")
return latest_model