Spaces:
Sleeping
Sleeping
adding proper training cycle handling
Browse files
model.py
CHANGED
|
@@ -17,13 +17,6 @@ current_dir = Path(__file__).parent
|
|
| 17 |
|
| 18 |
logger = logging.getLogger(__name__)
|
| 19 |
|
| 20 |
-
# Add these debug lines
|
| 21 |
-
logger.info("=== DEBUG INFO ===")
|
| 22 |
-
logger.info(f"Python path: {sys.path}")
|
| 23 |
-
logger.info(f"Current directory: {os.getcwd()}")
|
| 24 |
-
logger.info(f"Directory contents: {os.listdir('.')}")
|
| 25 |
-
logger.info("=== END DEBUG INFO ===")
|
| 26 |
-
|
| 27 |
# Move TextDataset class here
|
| 28 |
class TextDataset(Dataset):
|
| 29 |
def __init__(self, texts, labels, tokenizer, max_length=128):
|
|
@@ -43,10 +36,10 @@ class BertClassifier(LabelStudioMLBase):
|
|
| 43 |
super(BertClassifier, self).__init__(project_id=project_id, label_config=label_config)
|
| 44 |
|
| 45 |
# Load training configuration from environment variables
|
| 46 |
-
self.learning_rate = float(os.getenv('LEARNING_RATE'
|
| 47 |
-
self.num_train_epochs = int(os.getenv('NUM_TRAIN_EPOCHS'
|
| 48 |
-
self.weight_decay = float(os.getenv('WEIGHT_DECAY'
|
| 49 |
-
self.start_training_threshold = int(os.getenv('START_TRAINING_EACH_N_UPDATES'
|
| 50 |
|
| 51 |
logger.info("=== Training Configuration ===")
|
| 52 |
logger.info(f"β Learning rate: {self.learning_rate}")
|
|
@@ -55,38 +48,28 @@ class BertClassifier(LabelStudioMLBase):
|
|
| 55 |
logger.info(f"β Training threshold: {self.start_training_threshold}")
|
| 56 |
logger.info("============================")
|
| 57 |
|
| 58 |
-
|
| 59 |
-
logger.info(f"Label config length: {len(label_config) if label_config else 0}")
|
| 60 |
-
|
| 61 |
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 62 |
logger.info(f"Using device: {self.device}")
|
| 63 |
|
| 64 |
-
#
|
| 65 |
-
self.
|
| 66 |
-
'
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
'performance', 'post_type', 'pricing_tier',
|
| 71 |
-
'product', 'profession', 'pii', 'social_network',
|
| 72 |
-
'style_and_fashion', 'no_category'
|
| 73 |
-
]
|
| 74 |
|
| 75 |
-
|
| 76 |
-
os.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
self._model = AutoModelForSequenceClassification.from_pretrained(
|
| 81 |
-
'bert-base-uncased',
|
| 82 |
-
num_labels=len(self.categories)
|
| 83 |
-
)
|
| 84 |
-
self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
|
| 85 |
-
self._model.to(self.device)
|
| 86 |
-
logger.info("Successfully loaded BERT model and tokenizer")
|
| 87 |
-
except Exception as e:
|
| 88 |
-
logger.error(f"Error loading model: {str(e)}")
|
| 89 |
-
logger.error("Full error details:", exc_info=True)
|
| 90 |
|
| 91 |
def predict(self, tasks, **kwargs):
|
| 92 |
"""Generate predictions for a list of tasks."""
|
|
@@ -150,7 +133,6 @@ class BertClassifier(LabelStudioMLBase):
|
|
| 150 |
return predictions
|
| 151 |
|
| 152 |
def fit(self, event_data, data=None, **kwargs):
|
| 153 |
-
"""Train the model on a single annotation."""
|
| 154 |
start_time = datetime.now()
|
| 155 |
logger.info("=== FIT METHOD CALLED ===")
|
| 156 |
|
|
@@ -196,11 +178,12 @@ class BertClassifier(LabelStudioMLBase):
|
|
| 196 |
self._model.train()
|
| 197 |
logger.info("Starting training...")
|
| 198 |
|
| 199 |
-
#
|
|
|
|
| 200 |
for epoch in range(self.num_train_epochs):
|
| 201 |
logger.info(f"Starting epoch {epoch + 1}/{self.num_train_epochs}")
|
|
|
|
| 202 |
|
| 203 |
-
# Single example training
|
| 204 |
for batch in train_loader:
|
| 205 |
optimizer.zero_grad()
|
| 206 |
|
|
@@ -217,27 +200,32 @@ class BertClassifier(LabelStudioMLBase):
|
|
| 217 |
)
|
| 218 |
|
| 219 |
loss = outputs.loss
|
| 220 |
-
|
| 221 |
|
| 222 |
# Backward pass
|
| 223 |
loss.backward()
|
| 224 |
optimizer.step()
|
| 225 |
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
|
| 237 |
except Exception as e:
|
| 238 |
logger.error(f"Training error: {str(e)}")
|
| 239 |
-
|
| 240 |
-
return {'status': 'error', 'message': f'Training failed: {str(e)}'}
|
| 241 |
|
| 242 |
except Exception as e:
|
| 243 |
logger.error(f"Error in fit method: {str(e)}")
|
|
|
|
| 17 |
|
| 18 |
logger = logging.getLogger(__name__)
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
# Move TextDataset class here
|
| 21 |
class TextDataset(Dataset):
|
| 22 |
def __init__(self, texts, labels, tokenizer, max_length=128):
|
|
|
|
| 36 |
super(BertClassifier, self).__init__(project_id=project_id, label_config=label_config)
|
| 37 |
|
| 38 |
# Load training configuration from environment variables
|
| 39 |
+
self.learning_rate = float(os.getenv('LEARNING_RATE'))
|
| 40 |
+
self.num_train_epochs = int(os.getenv('NUM_TRAIN_EPOCHS'))
|
| 41 |
+
self.weight_decay = float(os.getenv('WEIGHT_DECAY'))
|
| 42 |
+
self.start_training_threshold = int(os.getenv('START_TRAINING_EACH_N_UPDATES'))
|
| 43 |
|
| 44 |
logger.info("=== Training Configuration ===")
|
| 45 |
logger.info(f"β Learning rate: {self.learning_rate}")
|
|
|
|
| 48 |
logger.info(f"β Training threshold: {self.start_training_threshold}")
|
| 49 |
logger.info("============================")
|
| 50 |
|
| 51 |
+
# Initialize model and move to device
|
|
|
|
|
|
|
| 52 |
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 53 |
logger.info(f"Using device: {self.device}")
|
| 54 |
|
| 55 |
+
# Initialize model
|
| 56 |
+
self._model = AutoModelForSequenceClassification.from_pretrained(
|
| 57 |
+
'bert-base-uncased',
|
| 58 |
+
num_labels=len(self.categories)
|
| 59 |
+
)
|
| 60 |
+
self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
+
# Load saved model if exists
|
| 63 |
+
model_path = os.path.join(self.model_dir, 'model_state.pt')
|
| 64 |
+
if os.path.exists(model_path):
|
| 65 |
+
try:
|
| 66 |
+
self._model.load_state_dict(torch.load(model_path))
|
| 67 |
+
logger.info(f"β Loaded saved model from {model_path}")
|
| 68 |
+
except Exception as e:
|
| 69 |
+
logger.error(f"Failed to load model: {str(e)}")
|
| 70 |
|
| 71 |
+
self._model.to(self.device)
|
| 72 |
+
logger.info("β Model ready")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
def predict(self, tasks, **kwargs):
|
| 75 |
"""Generate predictions for a list of tasks."""
|
|
|
|
| 133 |
return predictions
|
| 134 |
|
| 135 |
def fit(self, event_data, data=None, **kwargs):
|
|
|
|
| 136 |
start_time = datetime.now()
|
| 137 |
logger.info("=== FIT METHOD CALLED ===")
|
| 138 |
|
|
|
|
| 178 |
self._model.train()
|
| 179 |
logger.info("Starting training...")
|
| 180 |
|
| 181 |
+
# Training loop
|
| 182 |
+
total_loss = 0
|
| 183 |
for epoch in range(self.num_train_epochs):
|
| 184 |
logger.info(f"Starting epoch {epoch + 1}/{self.num_train_epochs}")
|
| 185 |
+
epoch_loss = 0
|
| 186 |
|
|
|
|
| 187 |
for batch in train_loader:
|
| 188 |
optimizer.zero_grad()
|
| 189 |
|
|
|
|
| 200 |
)
|
| 201 |
|
| 202 |
loss = outputs.loss
|
| 203 |
+
epoch_loss += loss.item()
|
| 204 |
|
| 205 |
# Backward pass
|
| 206 |
loss.backward()
|
| 207 |
optimizer.step()
|
| 208 |
|
| 209 |
+
avg_epoch_loss = epoch_loss / len(train_loader)
|
| 210 |
+
total_loss += avg_epoch_loss
|
| 211 |
+
logger.info(f"Epoch {epoch + 1} loss: {avg_epoch_loss:.4f}")
|
| 212 |
+
|
| 213 |
+
avg_training_loss = total_loss / self.num_train_epochs
|
| 214 |
+
logger.info(f"Average training loss: {avg_training_loss:.4f}")
|
| 215 |
+
|
| 216 |
+
# Save model
|
| 217 |
+
model_path = os.path.join(self.model_dir, 'model_state.pt')
|
| 218 |
+
torch.save(self._model.state_dict(), model_path)
|
| 219 |
+
logger.info(f"β Model saved to {model_path}")
|
| 220 |
+
|
| 221 |
+
return {
|
| 222 |
+
'status': 'ok',
|
| 223 |
+
'message': f'Training completed with avg loss: {avg_training_loss:.4f}'
|
| 224 |
+
}
|
| 225 |
|
| 226 |
except Exception as e:
|
| 227 |
logger.error(f"Training error: {str(e)}")
|
| 228 |
+
return {'status': 'error', 'message': str(e)}
|
|
|
|
| 229 |
|
| 230 |
except Exception as e:
|
| 231 |
logger.error(f"Error in fit method: {str(e)}")
|