b2u committed on
Commit
4debd04
·
1 Parent(s): 1818eaa

Training logic added

Browse files
Files changed (3) hide show
  1. model.py +86 -13
  2. utils/__init__.py +0 -0
  3. utils/dataset.py +25 -0
model.py CHANGED
@@ -2,9 +2,13 @@ import torch
2
  import logging
3
  import os
4
  import json
 
5
  from label_studio_ml.model import LabelStudioMLBase
6
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
 
7
  from sklearn.preprocessing import LabelEncoder
 
8
 
9
  logger = logging.getLogger(__name__)
10
 
@@ -13,12 +17,12 @@ class BertClassifier(LabelStudioMLBase):
13
  super(BertClassifier, self).__init__(project_id=project_id, label_config=label_config)
14
 
15
  logger.info(f"Initializing BertClassifier with project_id: {project_id}")
16
- logger.info(f"Label config: {label_config}")
17
 
18
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
  logger.info(f"Using device: {self.device}")
20
 
21
- # Define categories that match your Label Studio config
22
  self.categories = [
23
  'affiliate_classification', 'brand', 'business_and_career',
24
  'content_quality', 'date', 'demographic', 'event',
@@ -30,8 +34,7 @@ class BertClassifier(LabelStudioMLBase):
30
  ]
31
 
32
  self.model_dir = os.path.join(os.path.dirname(__file__), 'model')
33
- self._model = None
34
- self.tokenizer = None
35
 
36
  # Initialize model and tokenizer
37
  try:
@@ -60,13 +63,10 @@ class BertClassifier(LabelStudioMLBase):
60
 
61
  for task in tasks:
62
  logger.info(f"Processing task ID: {task.get('id')}")
63
-
64
- # Get the text to classify
65
  text = task['data'].get('text', '')
66
- logger.info(f"Text to predict: {text}")
67
 
68
  try:
69
- # Tokenize the text
70
  inputs = self.tokenizer(
71
  text,
72
  truncation=True,
@@ -74,8 +74,7 @@ class BertClassifier(LabelStudioMLBase):
74
  return_tensors='pt'
75
  ).to(self.device)
76
 
77
- # Get model prediction
78
- self._model.eval() # Set to evaluation mode
79
  with torch.no_grad():
80
  outputs = self._model(**inputs)
81
  probs = torch.softmax(outputs.logits, dim=1)
@@ -102,7 +101,6 @@ class BertClassifier(LabelStudioMLBase):
102
  except Exception as e:
103
  logger.error(f"Error processing individual task: {str(e)}")
104
  logger.error("Full error details:", exc_info=True)
105
- # Add empty prediction for failed task
106
  predictions.append({
107
  'result': [],
108
  'score': 0,
@@ -119,5 +117,80 @@ class BertClassifier(LabelStudioMLBase):
119
 
120
  def fit(self, completions, workdir=None, **kwargs):
121
  """Train model on labeled data"""
122
- logger.info('Starting model training...')
123
- return {'status': 'ok'}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import logging
3
  import os
4
  import json
5
+ from datetime import datetime
6
  from label_studio_ml.model import LabelStudioMLBase
7
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
8
+ from torch.utils.data import DataLoader
9
+ from torch.optim import AdamW
10
  from sklearn.preprocessing import LabelEncoder
11
+ from utils.dataset import TextDataset
12
 
13
  logger = logging.getLogger(__name__)
14
 
 
17
  super(BertClassifier, self).__init__(project_id=project_id, label_config=label_config)
18
 
19
  logger.info(f"Initializing BertClassifier with project_id: {project_id}")
20
+ logger.info(f"Label config length: {len(label_config) if label_config else 0}")
21
 
22
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
23
  logger.info(f"Using device: {self.device}")
24
 
25
+ # Define categories
26
  self.categories = [
27
  'affiliate_classification', 'brand', 'business_and_career',
28
  'content_quality', 'date', 'demographic', 'event',
 
34
  ]
35
 
36
  self.model_dir = os.path.join(os.path.dirname(__file__), 'model')
37
+ os.makedirs(self.model_dir, exist_ok=True)
 
38
 
39
  # Initialize model and tokenizer
40
  try:
 
63
 
64
  for task in tasks:
65
  logger.info(f"Processing task ID: {task.get('id')}")
 
 
66
  text = task['data'].get('text', '')
67
+ logger.info(f"Text to predict: {text[:100]}...")
68
 
69
  try:
 
70
  inputs = self.tokenizer(
71
  text,
72
  truncation=True,
 
74
  return_tensors='pt'
75
  ).to(self.device)
76
 
77
+ self._model.eval()
 
78
  with torch.no_grad():
79
  outputs = self._model(**inputs)
80
  probs = torch.softmax(outputs.logits, dim=1)
 
101
  except Exception as e:
102
  logger.error(f"Error processing individual task: {str(e)}")
103
  logger.error("Full error details:", exc_info=True)
 
104
  predictions.append({
105
  'result': [],
106
  'score': 0,
 
117
 
118
  def fit(self, completions, workdir=None, **kwargs):
119
  """Train model on labeled data"""
120
+ try:
121
+ logger.info('=== STARTING MODEL TRAINING ===')
122
+ logger.info(f'Received {len(completions)} completions for training')
123
+
124
+ # Extract training data
125
+ texts = []
126
+ labels = []
127
+ label_encoder = LabelEncoder()
128
+
129
+ for completion in completions:
130
+ logger.info(f"Processing completion: {completion.get('id')}")
131
+ text = completion['data'].get('text', '')
132
+ annotations = completion.get('annotations', [])
133
+ if annotations:
134
+ label = annotations[0].get('result', [])[0].get('value', {}).get('choices', [])[0]
135
+ texts.append(text)
136
+ labels.append(label)
137
+ logger.info(f"Added training example: '{text[:50]}...' -> {label}")
138
+
139
+ if not texts:
140
+ logger.warning("No valid training examples found")
141
+ return {'status': 'error', 'message': 'No valid training examples found'}
142
+
143
+ logger.info(f'Prepared {len(texts)} examples for training')
144
+
145
+ # Encode labels
146
+ encoded_labels = label_encoder.fit_transform(labels)
147
+
148
+ # Create dataset
149
+ dataset = TextDataset(texts, encoded_labels, self.tokenizer)
150
+ train_loader = DataLoader(dataset, batch_size=8, shuffle=True)
151
+
152
+ # Training setup
153
+ optimizer = AdamW(self._model.parameters(), lr=2e-5)
154
+ self._model.train()
155
+
156
+ # Training loop
157
+ num_epochs = 3
158
+ logger.info(f"Starting training for {num_epochs} epochs")
159
+
160
+ for epoch in range(num_epochs):
161
+ total_loss = 0
162
+ for batch in train_loader:
163
+ optimizer.zero_grad()
164
+
165
+ input_ids = batch['input_ids'].to(self.device)
166
+ attention_mask = batch['attention_mask'].to(self.device)
167
+ labels = batch['labels'].to(self.device)
168
+
169
+ outputs = self._model(
170
+ input_ids=input_ids,
171
+ attention_mask=attention_mask,
172
+ labels=labels
173
+ )
174
+
175
+ loss = outputs.loss
176
+ total_loss += loss.item()
177
+
178
+ loss.backward()
179
+ optimizer.step()
180
+
181
+ avg_loss = total_loss / len(train_loader)
182
+ logger.info(f"Epoch {epoch + 1}/{num_epochs}, Average Loss: {avg_loss:.4f}")
183
+
184
+ # Save the model
185
+ save_path = os.path.join(self.model_dir, 'trained_model')
186
+ self._model.save_pretrained(save_path)
187
+ self.tokenizer.save_pretrained(save_path)
188
+ logger.info(f"Model saved to {save_path}")
189
+
190
+ logger.info('=== TRAINING COMPLETED SUCCESSFULLY ===')
191
+ return {'status': 'ok', 'message': f'Model trained on {len(texts)} examples'}
192
+
193
+ except Exception as e:
194
+ logger.error(f"Error during training: {str(e)}")
195
+ logger.error("Full error details:", exc_info=True)
196
+ return {'status': 'error', 'message': str(e)}
utils/__init__.py ADDED
File without changes
utils/dataset.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
from torch.utils.data import Dataset


class TextDataset(Dataset):
    """Torch dataset serving tokenized texts for sequence classification.

    All texts are tokenized eagerly in the constructor; each example is
    returned as a dict of tensors with the class label under 'labels'.
    """

    def __init__(self, texts, labels, tokenizer, max_length=128):
        """
        Args:
            texts: list of input strings.
            labels: list of integer class labels, parallel to `texts`.
            tokenizer: HuggingFace tokenizer (a callable returning a
                mapping of per-example token lists).
            max_length: maximum sequence length passed to the tokenizer.
        """
        # One up-front tokenizer call produces a mapping such as
        # {'input_ids': [...], 'attention_mask': [...]}.
        self.encodings = tokenizer(
            texts, truncation=True, padding=True, max_length=max_length
        )
        self.labels = labels

    def __getitem__(self, idx):
        """Return example `idx` as a dict of tensors, label included."""
        example = {
            name: torch.tensor(values[idx])
            for name, values in self.encodings.items()
        }
        example['labels'] = torch.tensor(self.labels[idx])
        return example

    def __len__(self):
        """Return the number of examples in the dataset."""
        return len(self.labels)