b2u commited on
Commit
41e8855
·
1 Parent(s): b379c6b

webhook fix

Browse files
Files changed (1) hide show
  1. model.py +21 -3
model.py CHANGED
@@ -69,10 +69,28 @@ class BertClassifier(LabelStudioMLBase):
69
  finetuned_model_name = os.getenv('FINETUNED_MODEL_NAME', 'finetuned-model')
70
  _model = None
71
 
72
- def __init__(self, **kwargs):
73
- super().__init__(**kwargs)
74
- # Simplest default - just the placeholder
 
 
 
 
75
  self.instruction_template = os.getenv('MODEL_INSTRUCTIONS', '{text}')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  def get_labels(self):
78
  li = self.label_interface
 
69
  finetuned_model_name = os.getenv('FINETUNED_MODEL_NAME', 'finetuned-model')
70
  _model = None
71
 
72
+ def __init__(self, project_id=None, label_config=None, **kwargs):
73
+ # Initialize parent class properly
74
+ super(BertClassifier, self).__init__(project_id=project_id, label_config=label_config)
75
+
76
+ # Your existing initialization code
77
+ self.label_encoder = LabelEncoder()
78
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
79
  self.instruction_template = os.getenv('MODEL_INSTRUCTIONS', '{text}')
80
+
81
+ # Define your categories
82
+ self.categories = [
83
+ 'affiliate_classification', 'brand', 'business_and_career',
84
+ 'content_quality', 'date', 'demographic', 'event',
85
+ 'faith_and_religion', 'gaming', 'health',
86
+ 'internal_categorization', 'location', 'number',
87
+ 'performance', 'post_type', 'pricing_tier',
88
+ 'product', 'profession', 'pii', 'social_network',
89
+ 'style_and_fashion', 'no_category'
90
+ ]
91
+
92
+ # Fit label encoder with your categories
93
+ self.label_encoder.fit(self.categories)
94
 
95
  def get_labels(self):
96
  li = self.label_interface