b2u commited on
Commit
3178a98
·
1 Parent(s): fa5ac26

keep debugging

Browse files
Files changed (1) hide show
  1. model.py +15 -7
model.py CHANGED
@@ -76,17 +76,17 @@ class BertClassifier(LabelStudioMLBase):
76
  logger.info(f"Initializing BertClassifier with project_id: {project_id}")
77
  logger.info(f"Label config: {label_config}")
78
 
 
 
 
 
 
79
  # Initialize Label Studio client
80
  self.label_studio_client = self.connect_to_label_studio()
81
 
82
  self.label_encoder = LabelEncoder()
83
- self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
84
  self.instruction_template = os.getenv('MODEL_INSTRUCTIONS', '{text}')
85
 
86
- # Define model directory
87
- self.model_dir = os.path.join(os.path.dirname(__file__), 'model')
88
- os.makedirs(self.model_dir, exist_ok=True)
89
-
90
  # Define your categories
91
  self.categories = [
92
  'affiliate_classification', 'brand', 'business_and_career',
@@ -166,7 +166,7 @@ class BertClassifier(LabelStudioMLBase):
166
  },
167
  'score': 0.5 # Confidence score between 0 and 1
168
  }],
169
- 'model_version': self.model_dir
170
  })
171
 
172
  except Exception as e:
@@ -175,7 +175,7 @@ class BertClassifier(LabelStudioMLBase):
175
  # Return empty predictions in case of error
176
  predictions = [{
177
  'result': [],
178
- 'model_version': self.model_dir
179
  } for _ in tasks]
180
 
181
  return predictions
@@ -206,6 +206,10 @@ class BertClassifier(LabelStudioMLBase):
206
  logger.info('Starting model training...')
207
 
208
  try:
 
 
 
 
209
  # Debug completions
210
  logger.info("=== DEBUG COMPLETIONS START ===")
211
  logger.info(f"Type of completions: {type(completions)}")
@@ -233,6 +237,10 @@ class BertClassifier(LabelStudioMLBase):
233
 
234
  # Get annotations
235
  annotations = task.get('annotations', [])
 
 
 
 
236
  if annotations:
237
  logger.info(f"Found {len(annotations)} annotations for task {task.get('id')}")
238
  logger.info(f"Annotation content: {json.dumps(annotations[0], indent=2)}")
 
76
  logger.info(f"Initializing BertClassifier with project_id: {project_id}")
77
  logger.info(f"Label config: {label_config}")
78
 
79
+ # Initialize basic attributes
80
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
81
+ self.version = 'v0.0.1' # Define version explicitly
82
+ self.model_dir = f'BertClassifier-{self.version}' # Use versioned model directory
83
+
84
  # Initialize Label Studio client
85
  self.label_studio_client = self.connect_to_label_studio()
86
 
87
  self.label_encoder = LabelEncoder()
 
88
  self.instruction_template = os.getenv('MODEL_INSTRUCTIONS', '{text}')
89
 
 
 
 
 
90
  # Define your categories
91
  self.categories = [
92
  'affiliate_classification', 'brand', 'business_and_career',
 
166
  },
167
  'score': 0.5 # Confidence score between 0 and 1
168
  }],
169
+ 'model_version': self.version
170
  })
171
 
172
  except Exception as e:
 
175
  # Return empty predictions in case of error
176
  predictions = [{
177
  'result': [],
178
+ 'model_version': self.version
179
  } for _ in tasks]
180
 
181
  return predictions
 
206
  logger.info('Starting model training...')
207
 
208
  try:
209
+ # Get use_ground_truth parameter
210
+ use_ground_truth = kwargs.get('use_ground_truth', True)
211
+ logger.info(f"Training with use_ground_truth={use_ground_truth}")
212
+
213
  # Debug completions
214
  logger.info("=== DEBUG COMPLETIONS START ===")
215
  logger.info(f"Type of completions: {type(completions)}")
 
237
 
238
  # Get annotations
239
  annotations = task.get('annotations', [])
240
+ if use_ground_truth:
241
+ # Also include ground truth annotations
242
+ annotations.extend(task.get('ground_truth', []))
243
+
244
  if annotations:
245
  logger.info(f"Found {len(annotations)} annotations for task {task.get('id')}")
246
  logger.info(f"Annotation content: {json.dumps(annotations[0], indent=2)}")