b2u committed on
Commit
1b11c8f
·
1 Parent(s): c352a3a

rolling back

Browse files
Files changed (1) hide show
  1. model.py +9 -422
model.py CHANGED
@@ -1,450 +1,37 @@
1
- import os
2
  import torch
3
  import logging
4
- import pathlib
5
- import pickle
6
  import json
7
- from typing import List, Dict, Optional
8
  from label_studio_ml.model import LabelStudioMLBase
9
- from transformers import (
10
- AutoModelForSequenceClassification,
11
- AutoTokenizer,
12
- Trainer,
13
- TrainingArguments
14
- )
15
- from datasets import Dataset
16
  from sklearn.preprocessing import LabelEncoder
17
- from label_studio_sdk.label_interface.objects import PredictionValue
18
- from label_studio_ml.response import ModelResponse
19
- from label_studio_sdk import Client
20
 
21
  logger = logging.getLogger(__name__)
22
 
23
-
24
# Report the compute device once at import time so operators can see
# whether GPU acceleration is active for this backend process.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"There are {torch.cuda.device_count()} GPU(s) available.")
    print(f"We will use the GPU: {torch.cuda.get_device_name(0)}")
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")
31
-
32
-
33
class BertClassifier(LabelStudioMLBase):
    """BERT-based text classification backend for Label Studio.

    Fine-tunes a Hugging Face sequence-classification checkpoint (anything
    accepted by ``AutoModelForSequenceClassification``) on annotations pulled
    from a Label Studio project, then serves predictions back to it.

    Configuration is read from environment variables:
        LABEL_STUDIO_HOST: URL of the Label Studio instance.
        LABEL_STUDIO_API_KEY: API key for the Label Studio instance.
        START_TRAINING_EACH_N_UPDATES: labeled tasks to accumulate before training.
        LEARNING_RATE: learning rate for fine-tuning.
        NUM_TRAIN_EPOCHS: number of training epochs.
        WEIGHT_DECAY: weight decay for fine-tuning.
        BASELINE_MODEL_NAME: base checkpoint to fine-tune from.
        MODEL_DIR: directory where the trained model is saved.
        FINETUNED_MODEL_NAME: name given to the fine-tuned model.
    """

    LABEL_STUDIO_HOST = os.getenv('LABEL_STUDIO_HOST', 'http://localhost:8080')
    LABEL_STUDIO_API_KEY = os.getenv('LABEL_STUDIO_API_KEY')
    START_TRAINING_EACH_N_UPDATES = int(os.getenv('START_TRAINING_EACH_N_UPDATES', 10))
    LEARNING_RATE = float(os.getenv('LEARNING_RATE', 2e-5))
    NUM_TRAIN_EPOCHS = int(os.getenv('NUM_TRAIN_EPOCHS', 3))
    WEIGHT_DECAY = float(os.getenv('WEIGHT_DECAY', 0.01))
    baseline_model_name = os.getenv('BASELINE_MODEL_NAME', 'bert-base-multilingual-cased')
    MODEL_DIR = os.getenv('MODEL_DIR', './results')
    finetuned_model_name = os.getenv('FINETUNED_MODEL_NAME', 'finetuned-model')

    # Shared model handle; populated lazily on first use (see _lazy_init).
    _model = None
73
  def __init__(self, project_id=None, label_config=None, **kwargs):
74
  super(BertClassifier, self).__init__(project_id=project_id, label_config=label_config)
75
 
76
  logger.info(f"Initializing BertClassifier with project_id: {project_id}")
77
  logger.info(f"Label config: {label_config}")
78
 
79
- # Initialize basic attributes
80
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
81
- self.version = 'v0.0.1'
82
- self.model_dir = f'BertClassifier-{self.version}'
83
-
84
- # Define categories
85
- self.categories = [
86
- 'affiliate_classification', 'brand', 'business_and_career',
87
- 'content_quality', 'date', 'demographic', 'event',
88
- 'faith_and_religion', 'gaming', 'health',
89
- 'internal_categorization', 'location', 'number',
90
- 'performance', 'post_type', 'pricing_tier',
91
- 'product', 'profession', 'pii', 'social_network',
92
- 'style_and_fashion', 'no_category'
93
- ]
94
-
95
- # Initialize model and tokenizer as None - they'll be loaded when needed
96
  self._model = None
97
  self.tokenizer = None
98
-
99
- logger.info("BertClassifier initialized successfully")
100
-
101
- def get_labels(self):
102
- li = self.label_interface
103
- from_name, _, _ = li.get_first_tag_occurence('Choices', 'Text')
104
- tag = li.get_tag(from_name)
105
- return tag.labels
106
-
107
- def setup(self):
108
- """Setup the model - this is called when Label Studio connects"""
109
- try:
110
- # Initialize model directory
111
- os.makedirs(self.model_dir, exist_ok=True)
112
-
113
- # Return the required information for Label Studio
114
- return {
115
- 'model_class': 'BertClassifier', # Must match your class name
116
- 'model_params': {
117
- 'device': str(self.device),
118
- 'version': self.version
119
- },
120
- 'label_config': {
121
- 'from_name': 'sentiment',
122
- 'to_name': 'text',
123
- 'type': 'choices',
124
- 'labels': self.categories
125
- },
126
- 'api_version': '2' # Important: specify API version
127
- }
128
-
129
- except Exception as e:
130
- logger.error(f"Error in setup: {str(e)}")
131
- logger.error("Full error details:", exc_info=True)
132
- raise
133
-
134
- def _lazy_init(self):
135
- if not hasattr(self, '_model') or self._model is None:
136
- try:
137
- # Try to load fine-tuned model
138
- model_path = os.path.join(self.MODEL_DIR, 'fine_tuned_model')
139
- if os.path.exists(model_path):
140
- logger.info('Loading fine-tuned model...')
141
- self._model = AutoModelForSequenceClassification.from_pretrained(
142
- model_path,
143
- num_labels=len(self.categories)
144
- )
145
- self.tokenizer = AutoTokenizer.from_pretrained(model_path)
146
- else:
147
- logger.info('Loading base model...')
148
- self._model = AutoModelForSequenceClassification.from_pretrained(
149
- 'bert-base-multilingual-cased',
150
- num_labels=len(self.categories)
151
- )
152
- self.tokenizer = AutoTokenizer.from_pretrained('bert-base-multilingual-cased')
153
-
154
- self._model.to(self.device)
155
-
156
- # Load label encoder if exists
157
- label_encoder_path = os.path.join(self.MODEL_DIR, 'label_encoder.pkl')
158
- if os.path.exists(label_encoder_path):
159
- with open(label_encoder_path, 'rb') as f:
160
- self.label_encoder = pickle.load(f)
161
-
162
- except Exception as e:
163
- logger.error(f'Error initializing model: {str(e)}')
164
- raise
165
 
166
  def predict(self, tasks, **kwargs):
167
  """Make predictions for tasks"""
168
  predictions = []
169
-
170
- try:
171
- # Save tasks
172
- for task in tasks:
173
- self.save_task(task)
174
- logger.info(f"Saved task: {task.get('id', 'unknown')}")
175
-
176
- # Get text from task
177
- text = task.get('data', {}).get('text', '')
178
-
179
- # For now, return a default prediction (you can improve this later)
180
- predictions.append({
181
- 'result': [{
182
- 'from_name': 'sentiment',
183
- 'to_name': 'text',
184
- 'type': 'choices',
185
- 'value': {
186
- 'choices': ['no_category'] # Default prediction
187
- },
188
- 'score': 0.5 # Confidence score between 0 and 1
189
- }],
190
- 'model_version': self.version
191
- })
192
-
193
- except Exception as e:
194
- logger.error(f"Error in predict: {str(e)}")
195
- logger.error("Full error details:", exc_info=True)
196
- # Return empty predictions in case of error
197
- predictions = [{
198
  'result': [],
199
- 'model_version': self.version
200
- } for _ in tasks]
201
-
202
  return predictions
203
 
204
- def get_tasks(self):
205
- """Get tasks from Label Studio"""
206
- try:
207
- # Get tasks from Label Studio API
208
- params = {'project': self.project_id} if self.project_id else {}
209
- response = self.label_studio_client.make_request('GET', '/api/tasks', params=params)
210
- tasks = response.json()
211
-
212
- logger.info(f"Retrieved {len(tasks)} tasks from Label Studio API")
213
-
214
- # Debug first task if available
215
- if tasks:
216
- logger.info(f"First task content: {json.dumps(tasks[0], indent=2)}")
217
-
218
- return tasks
219
-
220
- except Exception as e:
221
- logger.error(f"Error retrieving tasks from Label Studio: {str(e)}")
222
- logger.error("Full error details:", exc_info=True)
223
- return []
224
-
225
  def fit(self, completions, workdir=None, **kwargs):
226
  """Train model on labeled data"""
227
  logger.info('Starting model training...')
228
-
229
- try:
230
- # Get use_ground_truth parameter
231
- use_ground_truth = kwargs.get('use_ground_truth', True)
232
- logger.info(f"Training with use_ground_truth={use_ground_truth}")
233
-
234
- # Debug completions
235
- logger.info("=== DEBUG COMPLETIONS START ===")
236
- logger.info(f"Type of completions: {type(completions)}")
237
- logger.info(f"Completions content: {completions}")
238
- logger.info("=== DEBUG COMPLETIONS END ===")
239
-
240
- # Extract training data
241
- texts, labels = [], []
242
-
243
- # Get tasks from Label Studio
244
- tasks = self.get_tasks()
245
- logger.info(f"Retrieved {len(tasks)} tasks from Label Studio")
246
-
247
- # Get interface info
248
- from_name = 'sentiment' # This matches your label config
249
- to_name = 'text' # This matches your label config
250
-
251
- for task in tasks:
252
- try:
253
- # Get text from task
254
- text = task['data'].get('text')
255
- if not text:
256
- logger.warning(f"No text found in task {task.get('id')}")
257
- continue
258
-
259
- # Get annotations
260
- annotations = task.get('annotations', [])
261
- if use_ground_truth:
262
- # Also include ground truth annotations
263
- annotations.extend(task.get('ground_truth', []))
264
-
265
- if annotations:
266
- logger.info(f"Found {len(annotations)} annotations for task {task.get('id')}")
267
- logger.info(f"Annotation content: {json.dumps(annotations[0], indent=2)}")
268
-
269
- if not annotations:
270
- logger.warning(f"No annotations found for task {task.get('id')}")
271
- continue
272
-
273
- for annotation in annotations:
274
- # Only use completed annotations
275
- if annotation.get('was_cancelled') or not annotation.get('completed_by'):
276
- continue
277
-
278
- try:
279
- # Get choices from result
280
- results = annotation.get('result', [])
281
- if not results:
282
- logger.warning(f"No results found in annotation for task {task.get('id')}")
283
- continue
284
-
285
- for result in results:
286
- if result.get('from_name') == from_name and result.get('to_name') == to_name:
287
- choices = result.get('value', {}).get('choices', [])
288
- if choices:
289
- label = choices[0]
290
- logger.info(f"Successfully extracted: Text='{text[:50]}...', Label='{label}'")
291
- texts.append(text)
292
- labels.append(label)
293
- break
294
-
295
- except Exception as e:
296
- logger.error(f"Error processing annotation: {str(e)}")
297
- continue
298
-
299
- except Exception as e:
300
- logger.error(f"Error processing task: {str(e)}")
301
- continue
302
-
303
- logger.info(f"Prepared {len(texts)} examples for training")
304
-
305
- if not texts:
306
- raise ValueError("No valid training examples found")
307
-
308
- # Convert labels to numeric using label encoder
309
- numeric_labels = self.label_encoder.transform(labels)
310
-
311
- # Create dataset
312
- train_dataset = Dataset.from_dict({
313
- 'text': texts,
314
- 'label': numeric_labels
315
- })
316
-
317
- # Initialize tokenizer and model if not already done
318
- self._lazy_init()
319
-
320
- # Tokenize the texts
321
- def tokenize_function(examples):
322
- return self.tokenizer(
323
- examples['text'],
324
- padding='max_length',
325
- truncation=True,
326
- max_length=512
327
- )
328
-
329
- tokenized_dataset = train_dataset.map(tokenize_function, batched=True)
330
-
331
- # Define training arguments
332
- training_args = TrainingArguments(
333
- output_dir=os.path.join(self.model_dir, "results"),
334
- num_train_epochs=3,
335
- per_device_train_batch_size=8,
336
- per_device_eval_batch_size=8,
337
- warmup_steps=500,
338
- weight_decay=0.01,
339
- logging_dir=os.path.join(self.model_dir, "logs"),
340
- logging_steps=10,
341
- save_strategy="epoch",
342
- )
343
-
344
- # Initialize trainer
345
- trainer = Trainer(
346
- model=self._model,
347
- args=training_args,
348
- train_dataset=tokenized_dataset,
349
- )
350
-
351
- # Train the model
352
- logger.info('Training started...')
353
- trainer.train()
354
- logger.info("Training completed successfully")
355
-
356
- # Save the fine-tuned model
357
- model_path = os.path.join(self.model_dir, 'fine_tuned_model')
358
- trainer.save_model(model_path)
359
- self.tokenizer.save_pretrained(model_path)
360
- logger.info(f"Model saved to {model_path}")
361
-
362
- # Save label encoder
363
- label_encoder_path = os.path.join(self.model_dir, 'label_encoder.pkl')
364
- with open(label_encoder_path, 'wb') as f:
365
- pickle.dump(self.label_encoder, f)
366
-
367
- return {
368
- 'model_path': model_path,
369
- 'label_encoder_path': label_encoder_path,
370
- 'categories': self.categories,
371
- 'metrics': trainer.state.log_history,
372
- 'status': 'success',
373
- 'train_size': len(texts)
374
- }
375
-
376
- except Exception as e:
377
- logger.error(f"Training failed: {str(e)}")
378
- logger.error("Full error details:", exc_info=True)
379
- return {
380
- 'status': 'error',
381
- 'error': str(e),
382
- 'train_size': len(texts) if 'texts' in locals() else 0
383
- }
384
-
385
- def save_task(self, task):
386
- """Save a task to local storage"""
387
- try:
388
- storage_path = os.path.join(self.model_dir, 'tasks.json')
389
- tasks = []
390
-
391
- # Load existing tasks
392
- if os.path.exists(storage_path):
393
- with open(storage_path, 'r') as f:
394
- tasks = json.load(f)
395
- logger.info(f"Loaded {len(tasks)} existing tasks")
396
-
397
- # Check if task already exists
398
- task_id = task.get('id')
399
- task_exists = False
400
-
401
- if task_id:
402
- for i, existing_task in enumerate(tasks):
403
- if existing_task.get('id') == task_id:
404
- # Preserve existing annotations
405
- existing_annotations = existing_task.get('annotations', [])
406
- if existing_annotations:
407
- task['annotations'] = existing_annotations
408
-
409
- # Update existing task
410
- tasks[i] = task
411
- task_exists = True
412
- logger.info(f"Updated existing task {task_id} with {len(existing_annotations)} annotations")
413
- break
414
-
415
- # Add new task if it doesn't exist
416
- if not task_exists:
417
- tasks.append(task)
418
- logger.info(f"Added new task {task_id}")
419
-
420
- # Save tasks
421
- with open(storage_path, 'w') as f:
422
- json.dump(tasks, f)
423
-
424
- logger.info(f"Saved tasks to storage. Total tasks: {len(tasks)}")
425
-
426
- except Exception as e:
427
- logger.error(f"Error saving task: {str(e)}")
428
- logger.error("Full error details:", exc_info=True)
429
-
430
- def connect_to_label_studio(self):
431
- """Connect to Label Studio API"""
432
- try:
433
- from label_studio_sdk import Client
434
-
435
- # Get Label Studio connection details from environment
436
- ls_url = os.getenv('LABEL_STUDIO_URL', 'http://localhost:8080')
437
- ls_token = os.getenv('LABEL_STUDIO_API_TOKEN')
438
-
439
- if not ls_token:
440
- raise ValueError("LABEL_STUDIO_API_TOKEN environment variable is not set")
441
-
442
- # Initialize client
443
- client = Client(url=ls_url, api_key=ls_token)
444
- logger.info(f"Connected to Label Studio at {ls_url}")
445
- return client
446
-
447
- except Exception as e:
448
- logger.error(f"Error connecting to Label Studio: {str(e)}")
449
- logger.error("Full error details:", exc_info=True)
450
- raise
 
 
1
  import torch
2
  import logging
3
+ import os
 
4
  import json
 
5
  from label_studio_ml.model import LabelStudioMLBase
6
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
 
 
 
 
 
7
  from sklearn.preprocessing import LabelEncoder
 
 
 
8
 
9
  logger = logging.getLogger(__name__)
10
 
 
 
 
 
 
 
 
 
 
 
11
class BertClassifier(LabelStudioMLBase):
    """Minimal BERT classifier backend for Label Studio (rolled-back stub)."""

    def __init__(self, project_id=None, label_config=None, **kwargs):
        """Record connection info and defer all model loading."""
        super().__init__(project_id=project_id, label_config=label_config)

        logger.info(f"Initializing BertClassifier with project_id: {project_id}")
        logger.info(f"Label config: {label_config}")

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Model artifacts live next to this file.
        self.model_dir = os.path.join(os.path.dirname(__file__), 'model')

        # Neither model nor tokenizer is loaded yet.
        self._model = None
        self.tokenizer = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  def predict(self, tasks, **kwargs):
24
  """Make predictions for tasks"""
25
  predictions = []
26
+ for task in tasks:
27
+ predictions.append({
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  'result': [],
29
+ 'score': 0,
30
+ 'model_version': self.model_dir
31
+ })
32
  return predictions
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  def fit(self, completions, workdir=None, **kwargs):
35
  """Train model on labeled data"""
36
  logger.info('Starting model training...')
37
+ return {'status': 'ok'}