Spaces:
Sleeping
Sleeping
fixing connection issue
Browse files
model.py
CHANGED
|
@@ -78,16 +78,10 @@ class BertClassifier(LabelStudioMLBase):
|
|
| 78 |
|
| 79 |
# Initialize basic attributes
|
| 80 |
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 81 |
-
self.version = 'v0.0.1'
|
| 82 |
-
self.model_dir = f'BertClassifier-{self.version}'
|
| 83 |
|
| 84 |
-
#
|
| 85 |
-
self.label_studio_client = self.connect_to_label_studio()
|
| 86 |
-
|
| 87 |
-
self.label_encoder = LabelEncoder()
|
| 88 |
-
self.instruction_template = os.getenv('MODEL_INSTRUCTIONS', '{text}')
|
| 89 |
-
|
| 90 |
-
# Define your categories
|
| 91 |
self.categories = [
|
| 92 |
'affiliate_classification', 'brand', 'business_and_career',
|
| 93 |
'content_quality', 'date', 'demographic', 'event',
|
|
@@ -98,8 +92,11 @@ class BertClassifier(LabelStudioMLBase):
|
|
| 98 |
'style_and_fashion', 'no_category'
|
| 99 |
]
|
| 100 |
|
| 101 |
-
#
|
| 102 |
-
self.
|
|
|
|
|
|
|
|
|
|
| 103 |
|
| 104 |
def get_labels(self):
|
| 105 |
li = self.label_interface
|
|
@@ -108,7 +105,31 @@ class BertClassifier(LabelStudioMLBase):
|
|
| 108 |
return tag.labels
|
| 109 |
|
| 110 |
def setup(self):
|
| 111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
def _lazy_init(self):
|
| 114 |
if not hasattr(self, '_model') or self._model is None:
|
|
|
|
| 78 |
|
| 79 |
# Initialize basic attributes
|
| 80 |
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 81 |
+
self.version = 'v0.0.1'
|
| 82 |
+
self.model_dir = f'BertClassifier-{self.version}'
|
| 83 |
|
| 84 |
+
# Define categories
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
self.categories = [
|
| 86 |
'affiliate_classification', 'brand', 'business_and_career',
|
| 87 |
'content_quality', 'date', 'demographic', 'event',
|
|
|
|
| 92 |
'style_and_fashion', 'no_category'
|
| 93 |
]
|
| 94 |
|
| 95 |
+
# Initialize model and tokenizer as None - they'll be loaded when needed
|
| 96 |
+
self._model = None
|
| 97 |
+
self.tokenizer = None
|
| 98 |
+
|
| 99 |
+
logger.info("BertClassifier initialized successfully")
|
| 100 |
|
| 101 |
def get_labels(self):
|
| 102 |
li = self.label_interface
|
|
|
|
| 105 |
return tag.labels
|
| 106 |
|
| 107 |
def setup(self):
|
| 108 |
+
"""Setup the model - this is called when Label Studio connects"""
|
| 109 |
+
try:
|
| 110 |
+
# Initialize model directory
|
| 111 |
+
os.makedirs(self.model_dir, exist_ok=True)
|
| 112 |
+
|
| 113 |
+
# Return the required information for Label Studio
|
| 114 |
+
return {
|
| 115 |
+
'model_class': 'BertClassifier', # Must match your class name
|
| 116 |
+
'model_params': {
|
| 117 |
+
'device': str(self.device),
|
| 118 |
+
'version': self.version
|
| 119 |
+
},
|
| 120 |
+
'label_config': {
|
| 121 |
+
'from_name': 'sentiment',
|
| 122 |
+
'to_name': 'text',
|
| 123 |
+
'type': 'choices',
|
| 124 |
+
'labels': self.categories
|
| 125 |
+
},
|
| 126 |
+
'api_version': '2' # Important: specify API version
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
except Exception as e:
|
| 130 |
+
logger.error(f"Error in setup: {str(e)}")
|
| 131 |
+
logger.error("Full error details:", exc_info=True)
|
| 132 |
+
raise
|
| 133 |
|
| 134 |
def _lazy_init(self):
|
| 135 |
if not hasattr(self, '_model') or self._model is None:
|