b2u commited on
Commit
0e770ae
·
1 Parent(s): 50e2ec7

initializing tokenizer earlier

Browse files
Files changed (1) hide show
  1. model.py +9 -12
model.py CHANGED
@@ -77,22 +77,20 @@ class BertClassifier(LabelStudioMLBase):
77
  logger.info(f"✓ Training threshold: {self.start_training_threshold}")
78
  logger.info("============================")
79
 
80
- def initialize(self):
81
- """
82
- Initialize model when server starts instead of when first prediction is requested.
83
- """
84
- logger.info("=== INITIALIZING MODEL ===")
85
-
86
- # Initialize model and move to device
87
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
88
- logger.info(f"Using device: {self.device}")
89
-
90
- # Initialize model
91
  self._model = AutoModelForSequenceClassification.from_pretrained(
92
  'bert-base-uncased',
93
  num_labels=len(self.categories)
94
  )
95
- self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
 
 
 
 
 
 
96
 
97
  # Load saved model if exists
98
  model_path = os.path.join(self.model_dir, 'model_state.pt')
@@ -103,7 +101,6 @@ class BertClassifier(LabelStudioMLBase):
103
  except Exception as e:
104
  logger.error(f"Failed to load model: {str(e)}")
105
 
106
- self._model.to(self.device)
107
  logger.info("✓ Model ready")
108
  return self
109
 
 
77
  logger.info(f"✓ Training threshold: {self.start_training_threshold}")
78
  logger.info("============================")
79
 
80
+ # Initialize tokenizer and model architecture (but not weights yet)
 
 
 
 
 
 
81
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
82
+ self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
 
 
83
  self._model = AutoModelForSequenceClassification.from_pretrained(
84
  'bert-base-uncased',
85
  num_labels=len(self.categories)
86
  )
87
+ self._model.to(self.device)
88
+
89
+ def initialize(self):
90
+ """
91
+ Initialize model when server starts instead of when first prediction is requested.
92
+ """
93
+ logger.info("=== INITIALIZING MODEL ===")
94
 
95
  # Load saved model if exists
96
  model_path = os.path.join(self.model_dir, 'model_state.pt')
 
101
  except Exception as e:
102
  logger.error(f"Failed to load model: {str(e)}")
103
 
 
104
  logger.info("✓ Model ready")
105
  return self
106