Spaces:
Sleeping
Sleeping
debugging
Browse files
model.py
CHANGED
|
@@ -76,6 +76,9 @@ class BertClassifier(LabelStudioMLBase):
|
|
| 76 |
logger.info(f"Initializing BertClassifier with project_id: {project_id}")
|
| 77 |
logger.info(f"Label config: {label_config}")
|
| 78 |
|
|
|
|
|
|
|
|
|
|
| 79 |
self.label_encoder = LabelEncoder()
|
| 80 |
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 81 |
self.instruction_template = os.getenv('MODEL_INSTRUCTIONS', '{text}')
|
|
@@ -178,19 +181,24 @@ class BertClassifier(LabelStudioMLBase):
|
|
| 178 |
return predictions
|
| 179 |
|
| 180 |
def get_tasks(self):
|
| 181 |
-
"""Get tasks from
|
| 182 |
try:
|
| 183 |
-
|
| 184 |
-
if
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
except Exception as e:
|
| 193 |
-
logger.error(f"Error retrieving tasks: {str(e)}")
|
|
|
|
| 194 |
return []
|
| 195 |
|
| 196 |
def fit(self, completions, workdir=None, **kwargs):
|
|
@@ -207,25 +215,13 @@ class BertClassifier(LabelStudioMLBase):
|
|
| 207 |
# Extract training data
|
| 208 |
texts, labels = [], []
|
| 209 |
|
| 210 |
-
#
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
logger.info(f"Retrieved {len(tasks)} tasks from storage")
|
| 214 |
-
# Debug first task if available
|
| 215 |
-
if tasks:
|
| 216 |
-
logger.info(f"First task content: {json.dumps(tasks[0], indent=2)}")
|
| 217 |
-
else:
|
| 218 |
-
# If completions is a list, use it directly
|
| 219 |
-
tasks = completions if isinstance(completions, list) else [completions]
|
| 220 |
-
logger.info(f"Using {len(tasks)} tasks from completions")
|
| 221 |
-
if tasks:
|
| 222 |
-
logger.info(f"First completion content: {json.dumps(tasks[0], indent=2)}")
|
| 223 |
-
|
| 224 |
-
logger.info(f"Processing {len(tasks)} tasks")
|
| 225 |
|
| 226 |
# Get interface info
|
| 227 |
-
from_name
|
| 228 |
-
|
| 229 |
|
| 230 |
for task in tasks:
|
| 231 |
try:
|
|
@@ -237,17 +233,20 @@ class BertClassifier(LabelStudioMLBase):
|
|
| 237 |
|
| 238 |
# Get annotations
|
| 239 |
annotations = task.get('annotations', [])
|
| 240 |
-
|
|
|
|
|
|
|
| 241 |
|
| 242 |
if not annotations:
|
| 243 |
logger.warning(f"No annotations found for task {task.get('id')}")
|
| 244 |
continue
|
| 245 |
|
| 246 |
for annotation in annotations:
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
|
|
|
|
| 251 |
# Get choices from result
|
| 252 |
results = annotation.get('result', [])
|
| 253 |
if not results:
|
|
@@ -263,7 +262,7 @@ class BertClassifier(LabelStudioMLBase):
|
|
| 263 |
texts.append(text)
|
| 264 |
labels.append(label)
|
| 265 |
break
|
| 266 |
-
|
| 267 |
except Exception as e:
|
| 268 |
logger.error(f"Error processing annotation: {str(e)}")
|
| 269 |
continue
|
|
@@ -398,3 +397,25 @@ class BertClassifier(LabelStudioMLBase):
|
|
| 398 |
except Exception as e:
|
| 399 |
logger.error(f"Error saving task: {str(e)}")
|
| 400 |
logger.error("Full error details:", exc_info=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
logger.info(f"Initializing BertClassifier with project_id: {project_id}")
|
| 77 |
logger.info(f"Label config: {label_config}")
|
| 78 |
|
| 79 |
+
# Initialize Label Studio client
|
| 80 |
+
self.label_studio_client = self.connect_to_label_studio()
|
| 81 |
+
|
| 82 |
self.label_encoder = LabelEncoder()
|
| 83 |
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 84 |
self.instruction_template = os.getenv('MODEL_INSTRUCTIONS', '{text}')
|
|
|
|
| 181 |
return predictions
|
| 182 |
|
| 183 |
def get_tasks(self):
|
| 184 |
+
"""Get tasks from Label Studio"""
|
| 185 |
try:
|
| 186 |
+
# Get tasks from Label Studio API
|
| 187 |
+
params = {'project': self.project_id} if self.project_id else {}
|
| 188 |
+
response = self.label_studio_client.make_request('GET', '/api/tasks', params=params)
|
| 189 |
+
tasks = response.json()
|
| 190 |
+
|
| 191 |
+
logger.info(f"Retrieved {len(tasks)} tasks from Label Studio API")
|
| 192 |
+
|
| 193 |
+
# Debug first task if available
|
| 194 |
+
if tasks:
|
| 195 |
+
logger.info(f"First task content: {json.dumps(tasks[0], indent=2)}")
|
| 196 |
+
|
| 197 |
+
return tasks
|
| 198 |
+
|
| 199 |
except Exception as e:
|
| 200 |
+
logger.error(f"Error retrieving tasks from Label Studio: {str(e)}")
|
| 201 |
+
logger.error("Full error details:", exc_info=True)
|
| 202 |
return []
|
| 203 |
|
| 204 |
def fit(self, completions, workdir=None, **kwargs):
|
|
|
|
| 215 |
# Extract training data
|
| 216 |
texts, labels = [], []
|
| 217 |
|
| 218 |
+
# Get tasks from Label Studio
|
| 219 |
+
tasks = self.get_tasks()
|
| 220 |
+
logger.info(f"Retrieved {len(tasks)} tasks from Label Studio")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
|
| 222 |
# Get interface info
|
| 223 |
+
from_name = 'sentiment' # This matches your label config
|
| 224 |
+
to_name = 'text' # This matches your label config
|
| 225 |
|
| 226 |
for task in tasks:
|
| 227 |
try:
|
|
|
|
| 233 |
|
| 234 |
# Get annotations
|
| 235 |
annotations = task.get('annotations', [])
|
| 236 |
+
if annotations:
|
| 237 |
+
logger.info(f"Found {len(annotations)} annotations for task {task.get('id')}")
|
| 238 |
+
logger.info(f"Annotation content: {json.dumps(annotations[0], indent=2)}")
|
| 239 |
|
| 240 |
if not annotations:
|
| 241 |
logger.warning(f"No annotations found for task {task.get('id')}")
|
| 242 |
continue
|
| 243 |
|
| 244 |
for annotation in annotations:
|
| 245 |
+
# Only use completed annotations
|
| 246 |
+
if annotation.get('was_cancelled') or not annotation.get('completed_by'):
|
| 247 |
+
continue
|
| 248 |
|
| 249 |
+
try:
|
| 250 |
# Get choices from result
|
| 251 |
results = annotation.get('result', [])
|
| 252 |
if not results:
|
|
|
|
| 262 |
texts.append(text)
|
| 263 |
labels.append(label)
|
| 264 |
break
|
| 265 |
+
|
| 266 |
except Exception as e:
|
| 267 |
logger.error(f"Error processing annotation: {str(e)}")
|
| 268 |
continue
|
|
|
|
| 397 |
except Exception as e:
|
| 398 |
logger.error(f"Error saving task: {str(e)}")
|
| 399 |
logger.error("Full error details:", exc_info=True)
|
| 400 |
+
|
| 401 |
+
def connect_to_label_studio(self):
|
| 402 |
+
"""Connect to Label Studio API"""
|
| 403 |
+
try:
|
| 404 |
+
from label_studio_sdk import Client
|
| 405 |
+
|
| 406 |
+
# Get Label Studio connection details from environment
|
| 407 |
+
ls_url = os.getenv('LABEL_STUDIO_URL', 'http://localhost:8080')
|
| 408 |
+
ls_token = os.getenv('LABEL_STUDIO_API_TOKEN')
|
| 409 |
+
|
| 410 |
+
if not ls_token:
|
| 411 |
+
raise ValueError("LABEL_STUDIO_API_TOKEN environment variable is not set")
|
| 412 |
+
|
| 413 |
+
# Initialize client
|
| 414 |
+
client = Client(url=ls_url, api_key=ls_token)
|
| 415 |
+
logger.info(f"Connected to Label Studio at {ls_url}")
|
| 416 |
+
return client
|
| 417 |
+
|
| 418 |
+
except Exception as e:
|
| 419 |
+
logger.error(f"Error connecting to Label Studio: {str(e)}")
|
| 420 |
+
logger.error("Full error details:", exc_info=True)
|
| 421 |
+
raise
|