lawlevisan commited on
Commit
737ac7f
·
verified ·
1 Parent(s): dbbcdc9

Update src/predict.py

Browse files
Files changed (1) hide show
  1. src/predict.py +9 -26
src/predict.py CHANGED
@@ -268,49 +268,32 @@ def validate_and_fix_config(model_path: str) -> bool:
268
  # Enhanced model loading with multiple fallback strategies
269
  # =======================
270
  def load_model_with_fallback(model_name: str) -> bool:
271
- """Simplified model loading for HF Spaces"""
272
  global model, tokenizer, model_loaded
273
 
274
  with model_lock:
275
  if model_loaded:
276
  return True
277
 
278
- logger.info(f"Loading model: {model_name}")
279
 
280
  try:
281
- # Check if local model exists
282
- if os.path.exists(model_name):
283
- logger.info("Loading local model...")
284
- tokenizer = AutoTokenizer.from_pretrained(
285
- model_name,
286
- use_fast=True,
287
- do_lower_case=True,
288
- local_files_only=True # Force local loading
289
- )
290
- model = AutoModelForSequenceClassification.from_pretrained(
291
- model_name,
292
- num_labels=2,
293
- ignore_mismatched_sizes=True,
294
- local_files_only=True # Force local loading
295
- )
296
- else:
297
- # Fallback to a working pre-trained model
298
- logger.warning("Local model not found, using fallback...")
299
- tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
300
- model = AutoModelForSequenceClassification.from_pretrained(
301
- 'distilbert-base-uncased',
302
- num_labels=2
303
- )
304
 
305
  model.to(device)
306
  model.eval()
307
  model_loaded = True
308
- logger.info("Model loaded successfully")
309
  return True
310
 
311
  except Exception as e:
312
  logger.error(f"Model loading failed: {e}")
313
  return False
 
314
  # =======================
315
  # Optimized prediction function with enhanced accuracy
316
  # =======================
 
268
  # Enhanced model loading with multiple fallback strategies
269
  # =======================
270
  def load_model_with_fallback(model_name: str) -> bool:
271
+ """Use standard model - bypass custom model for now"""
272
  global model, tokenizer, model_loaded
273
 
274
  with model_lock:
275
  if model_loaded:
276
  return True
277
 
278
+ logger.info("Using standard DistilBERT model (custom model has tokenizer issues)")
279
 
280
  try:
281
+ tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
282
+ model = AutoModelForSequenceClassification.from_pretrained(
283
+ 'distilbert-base-uncased',
284
+ num_labels=2
285
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
 
287
  model.to(device)
288
  model.eval()
289
  model_loaded = True
290
+ logger.info("Standard model loaded successfully")
291
  return True
292
 
293
  except Exception as e:
294
  logger.error(f"Model loading failed: {e}")
295
  return False
296
+
297
  # =======================
298
  # Optimized prediction function with enhanced accuracy
299
  # =======================