Adding training initiation and logging
model.py CHANGED

@@ -7,6 +7,8 @@ import json
 import torch
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 from label_studio_ml.model import LabelStudioMLBase, ModelResponse
+from peft import get_peft_model, LoraConfig
+import time

 logger = logging.getLogger(__name__)

@@ -180,26 +182,61 @@ Category:"""
         return predictions

     def fit(self, event, data, **kwargs):
-        """Handle annotation events from Label Studio
-
-
-            event (str): Event type ('ANNOTATION_CREATED', 'ANNOTATION_UPDATED', 'START_TRAINING')
-            data (dict): Event payload with annotation details
-        """
+        """Handle annotation events from Label Studio"""
+        start_time = time.time()
+        logger.info("Starting training session...")

         valid_events = {'ANNOTATION_CREATED', 'ANNOTATION_UPDATED', 'START_TRAINING'}
         if event not in valid_events:
             logger.warning(f"Skip training: event {event} is not supported")
             return

-
-
-
-
-
-
-
-
-
-
+        try:
+            # Extract text and label
+            text = data['task']['data']['text']
+            label = data['annotation']['result'][0]['value']['choices'][0]
+
+            # Configure LoRA
+            lora_config = LoraConfig(
+                r=int(os.getenv('LORA_R', '8')),
+                lora_alpha=int(os.getenv('LORA_ALPHA', '32')),
+                target_modules=os.getenv('LORA_TARGET_MODULES', 'q,v').split(','),
+                lora_dropout=float(os.getenv('LORA_DROPOUT', '0.1')),
+                bias="none",
+                task_type="SEQ_2_SEQ_LM"
+            )
+
+            logger.info("Preparing model for training...")
+            model = get_peft_model(self.model, lora_config)
+            model.print_trainable_parameters()
+
+            # Training loop
+            logger.info("Starting training loop...")
+            optimizer = torch.optim.AdamW(model.parameters(), lr=float(os.getenv('LEARNING_RATE', '1e-4')))
+
+            # Single training step for this annotation
+            model.train()
+            optimizer.zero_grad()
+
+            inputs = self.tokenizer(text, return_tensors="pt", max_length=self.max_length, truncation=True).to(self.device)
+            labels = self.tokenizer(label, return_tensors="pt", max_length=self.generation_max_length, truncation=True).to(self.device)
+
+            outputs = model(**inputs, labels=labels["input_ids"])
+            loss = outputs.loss
+            loss.backward()
+            optimizer.step()
+
+            logger.info(f"Training step completed. Loss: {loss.item():.4f}")
+
+            # Switch back to eval mode
+            model.eval()
+
+            training_time = time.time() - start_time
+            logger.info(f"Training session completed successfully in {training_time:.2f} seconds with tag: '{text}' and label: '{label}'")
+
+        except Exception as e:
+            training_time = time.time() - start_time
+            logger.error(f"Training failed after {training_time:.2f} seconds")
+            logger.error(f"Error during training: {str(e)}")
+            raise

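For reference, the fit() handler above indexes into the event payload as data['task']['data']['text'] and data['annotation']['result'][0]['value']['choices'][0], which implies a Label Studio webhook payload shaped roughly like the sketch below. Only the key structure is taken from the code; the text and label values are invented for illustration.

# Hypothetical ANNOTATION_CREATED payload matching the keys fit() reads;
# the values are illustrative, not from the actual project.
sample_event = "ANNOTATION_CREATED"
sample_data = {
    "task": {
        "data": {"text": "The battery drains within two hours."}
    },
    "annotation": {
        "result": [
            {"value": {"choices": ["Hardware Issue"]}}
        ]
    },
}

# The ML backend would then be invoked along the lines of:
#   backend.fit(sample_event, sample_data)
# where backend is an instance of the LabelStudioMLBase subclass in model.py.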
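To see the LoRA pattern in isolation: the commit wraps self.model with get_peft_model() so that only the low-rank adapter weights receive gradients, then runs a single optimizer step per annotation. Below is a minimal standalone sketch of that flow, assuming a T5-family seq2seq checkpoint (the default target_modules 'q,v' are T5 attention projection names); the model name, sample text, and hyperparameters are illustrative stand-ins, not values from the project.

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

model_name = "google/flan-t5-small"  # assumption: any T5-style seq2seq model works here
tokenizer = AutoTokenizer.from_pretrained(model_name)
base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Same LoRA settings as the env-var defaults in fit(): r=8, alpha=32, q/v targets
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.1,
    bias="none",
    task_type="SEQ_2_SEQ_LM",
)
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()  # only adapter weights are marked trainable

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# One gradient step on a single (text, label) pair, mirroring fit()
inputs = tokenizer("The battery drains within two hours.", return_tensors="pt", truncation=True)
labels = tokenizer("Hardware Issue", return_tensors="pt", truncation=True)

model.train()
optimizer.zero_grad()
loss = model(**inputs, labels=labels["input_ids"]).loss
loss.backward()
optimizer.step()
model.eval()

print(f"single-step loss: {loss.item():.4f}")

Because only the adapters are updated, a step like this stays cheap, which is presumably why the commit trains online inside the event handler rather than batching annotations.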