Spaces:
Sleeping
Sleeping
keep debugging
Browse files
model.py
CHANGED
|
@@ -76,17 +76,17 @@ class BertClassifier(LabelStudioMLBase):
|
|
| 76 |
logger.info(f"Initializing BertClassifier with project_id: {project_id}")
|
| 77 |
logger.info(f"Label config: {label_config}")
|
| 78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
# Initialize Label Studio client
|
| 80 |
self.label_studio_client = self.connect_to_label_studio()
|
| 81 |
|
| 82 |
self.label_encoder = LabelEncoder()
|
| 83 |
-
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 84 |
self.instruction_template = os.getenv('MODEL_INSTRUCTIONS', '{text}')
|
| 85 |
|
| 86 |
-
# Define model directory
|
| 87 |
-
self.model_dir = os.path.join(os.path.dirname(__file__), 'model')
|
| 88 |
-
os.makedirs(self.model_dir, exist_ok=True)
|
| 89 |
-
|
| 90 |
# Define your categories
|
| 91 |
self.categories = [
|
| 92 |
'affiliate_classification', 'brand', 'business_and_career',
|
|
@@ -166,7 +166,7 @@ class BertClassifier(LabelStudioMLBase):
|
|
| 166 |
},
|
| 167 |
'score': 0.5 # Confidence score between 0 and 1
|
| 168 |
}],
|
| 169 |
-
'model_version': self.
|
| 170 |
})
|
| 171 |
|
| 172 |
except Exception as e:
|
|
@@ -175,7 +175,7 @@ class BertClassifier(LabelStudioMLBase):
|
|
| 175 |
# Return empty predictions in case of error
|
| 176 |
predictions = [{
|
| 177 |
'result': [],
|
| 178 |
-
'model_version': self.
|
| 179 |
} for _ in tasks]
|
| 180 |
|
| 181 |
return predictions
|
|
@@ -206,6 +206,10 @@ class BertClassifier(LabelStudioMLBase):
|
|
| 206 |
logger.info('Starting model training...')
|
| 207 |
|
| 208 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
# Debug completions
|
| 210 |
logger.info("=== DEBUG COMPLETIONS START ===")
|
| 211 |
logger.info(f"Type of completions: {type(completions)}")
|
|
@@ -233,6 +237,10 @@ class BertClassifier(LabelStudioMLBase):
|
|
| 233 |
|
| 234 |
# Get annotations
|
| 235 |
annotations = task.get('annotations', [])
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
if annotations:
|
| 237 |
logger.info(f"Found {len(annotations)} annotations for task {task.get('id')}")
|
| 238 |
logger.info(f"Annotation content: {json.dumps(annotations[0], indent=2)}")
|
|
|
|
| 76 |
logger.info(f"Initializing BertClassifier with project_id: {project_id}")
|
| 77 |
logger.info(f"Label config: {label_config}")
|
| 78 |
|
| 79 |
+
# Initialize basic attributes
|
| 80 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 81 |
+
self.version = 'v0.0.1' # Define version explicitly
|
| 82 |
+
self.model_dir = f'BertClassifier-{self.version}' # Use versioned model directory
|
| 83 |
+
|
| 84 |
# Initialize Label Studio client
|
| 85 |
self.label_studio_client = self.connect_to_label_studio()
|
| 86 |
|
| 87 |
self.label_encoder = LabelEncoder()
|
|
|
|
| 88 |
self.instruction_template = os.getenv('MODEL_INSTRUCTIONS', '{text}')
|
| 89 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
# Define your categories
|
| 91 |
self.categories = [
|
| 92 |
'affiliate_classification', 'brand', 'business_and_career',
|
|
|
|
| 166 |
},
|
| 167 |
'score': 0.5 # Confidence score between 0 and 1
|
| 168 |
}],
|
| 169 |
+
'model_version': self.version
|
| 170 |
})
|
| 171 |
|
| 172 |
except Exception as e:
|
|
|
|
| 175 |
# Return empty predictions in case of error
|
| 176 |
predictions = [{
|
| 177 |
'result': [],
|
| 178 |
+
'model_version': self.version
|
| 179 |
} for _ in tasks]
|
| 180 |
|
| 181 |
return predictions
|
|
|
|
| 206 |
logger.info('Starting model training...')
|
| 207 |
|
| 208 |
try:
|
| 209 |
+
# Get use_ground_truth parameter
|
| 210 |
+
use_ground_truth = kwargs.get('use_ground_truth', True)
|
| 211 |
+
logger.info(f"Training with use_ground_truth={use_ground_truth}")
|
| 212 |
+
|
| 213 |
# Debug completions
|
| 214 |
logger.info("=== DEBUG COMPLETIONS START ===")
|
| 215 |
logger.info(f"Type of completions: {type(completions)}")
|
|
|
|
| 237 |
|
| 238 |
# Get annotations
|
| 239 |
annotations = task.get('annotations', [])
|
| 240 |
+
if use_ground_truth:
|
| 241 |
+
# Also include ground truth annotations
|
| 242 |
+
annotations.extend(task.get('ground_truth', []))
|
| 243 |
+
|
| 244 |
if annotations:
|
| 245 |
logger.info(f"Found {len(annotations)} annotations for task {task.get('id')}")
|
| 246 |
logger.info(f"Annotation content: {json.dumps(annotations[0], indent=2)}")
|