b2u commited on
Commit
c352a3a
·
1 Parent(s): 3178a98

fixing connection issue

Browse files
Files changed (1) hide show
  1. model.py +33 -12
model.py CHANGED
@@ -78,16 +78,10 @@ class BertClassifier(LabelStudioMLBase):
78
 
79
  # Initialize basic attributes
80
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
81
- self.version = 'v0.0.1' # Define version explicitly
82
- self.model_dir = f'BertClassifier-{self.version}' # Use versioned model directory
83
 
84
- # Initialize Label Studio client
85
- self.label_studio_client = self.connect_to_label_studio()
86
-
87
- self.label_encoder = LabelEncoder()
88
- self.instruction_template = os.getenv('MODEL_INSTRUCTIONS', '{text}')
89
-
90
- # Define your categories
91
  self.categories = [
92
  'affiliate_classification', 'brand', 'business_and_career',
93
  'content_quality', 'date', 'demographic', 'event',
@@ -98,8 +92,11 @@ class BertClassifier(LabelStudioMLBase):
98
  'style_and_fashion', 'no_category'
99
  ]
100
 
101
- # Fit label encoder with your categories
102
- self.label_encoder.fit(self.categories)
 
 
 
103
 
104
  def get_labels(self):
105
  li = self.label_interface
@@ -108,7 +105,31 @@ class BertClassifier(LabelStudioMLBase):
108
  return tag.labels
109
 
110
  def setup(self):
111
- self.set("model_version", f'{self.__class__.__name__}-v0.0.1')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
  def _lazy_init(self):
114
  if not hasattr(self, '_model') or self._model is None:
 
78
 
79
  # Initialize basic attributes
80
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
81
+ self.version = 'v0.0.1'
82
+ self.model_dir = f'BertClassifier-{self.version}'
83
 
84
+ # Define categories
 
 
 
 
 
 
85
  self.categories = [
86
  'affiliate_classification', 'brand', 'business_and_career',
87
  'content_quality', 'date', 'demographic', 'event',
 
92
  'style_and_fashion', 'no_category'
93
  ]
94
 
95
+ # Initialize model and tokenizer as None - they'll be loaded when needed
96
+ self._model = None
97
+ self.tokenizer = None
98
+
99
+ logger.info("BertClassifier initialized successfully")
100
 
101
  def get_labels(self):
102
  li = self.label_interface
 
105
  return tag.labels
106
 
107
  def setup(self):
108
+ """Setup the model - this is called when Label Studio connects"""
109
+ try:
110
+ # Initialize model directory
111
+ os.makedirs(self.model_dir, exist_ok=True)
112
+
113
+ # Return the required information for Label Studio
114
+ return {
115
+ 'model_class': 'BertClassifier', # Must match your class name
116
+ 'model_params': {
117
+ 'device': str(self.device),
118
+ 'version': self.version
119
+ },
120
+ 'label_config': {
121
+ 'from_name': 'sentiment',
122
+ 'to_name': 'text',
123
+ 'type': 'choices',
124
+ 'labels': self.categories
125
+ },
126
+ 'api_version': '2' # Important: specify API version
127
+ }
128
+
129
+ except Exception as e:
130
+ logger.error(f"Error in setup: {str(e)}")
131
+ logger.error("Full error details:", exc_info=True)
132
+ raise
133
 
134
  def _lazy_init(self):
135
  if not hasattr(self, '_model') or self._model is None: