lawlevisan committed
Commit e3d51ef · verified · 1 Parent(s): 35a5d0d

Update src/predict.py

Files changed (1)
  1. src/predict.py +19 -89
src/predict.py CHANGED
@@ -268,7 +268,7 @@ def validate_and_fix_config(model_path: str) -> bool:
 # Enhanced model loading with multiple fallback strategies
 # =======================
 def load_model_with_fallback(model_name: str) -> bool:
-    """Enhanced model loading with multiple fallback strategies"""
+    """Simplified model loading for HF Spaces"""
     global model, tokenizer, model_loaded
 
     with model_lock:
@@ -277,110 +277,40 @@ def load_model_with_fallback(model_name: str) -> bool:
 
         logger.info(f"Loading model: {model_name}")
 
-        # Strategy 1: Load local model with Auto classes (most compatible)
-        if os.path.exists(model_name):
-            try:
-                logger.info("Strategy 1: Loading with Auto classes...")
+        try:
+            # Check if local model exists
+            if os.path.exists(model_name):
+                logger.info("Loading local model...")
                 tokenizer = AutoTokenizer.from_pretrained(
                     model_name,
                     use_fast=True,
-                    do_lower_case=True
+                    do_lower_case=True,
+                    local_files_only=True  # Force local loading
                 )
                 model = AutoModelForSequenceClassification.from_pretrained(
                     model_name,
                     num_labels=2,
-                    ignore_mismatched_sizes=True
-                )
-                model.to(device)
-                model.eval()
-                logger.info("✅ Successfully loaded with Auto classes")
-                model_loaded = True
-                return True
-            except Exception as e:
-                logger.error(f"Strategy 1 failed: {e}")
-
-        # Strategy 2: Load with DistilBERT classes
-        if os.path.exists(model_name):
-            try:
-                logger.info("Strategy 2: Loading with DistilBERT classes...")
-                validate_and_fix_config(model_name)
-
-                tokenizer = DistilBertTokenizerFast.from_pretrained(
-                    model_name,
-                    do_lower_case=True
-                )
-                model = DistilBertForSequenceClassification.from_pretrained(
-                    model_name,
-                    ignore_mismatched_sizes=True
+                    ignore_mismatched_sizes=True,
+                    local_files_only=True  # Force local loading
                 )
-                model.to(device)
-                model.eval()
-                logger.info("✅ Successfully loaded with DistilBERT classes")
-                model_loaded = True
-                return True
-            except Exception as e:
-                logger.error(f"Strategy 2 failed: {e}")
-
-        # Strategy 3: Create model with custom config + load weights
-        if os.path.exists(model_name):
-            try:
-                logger.info("Strategy 3: Loading with custom configuration...")
-                config = DistilBertConfig(
-                    vocab_size=30522,
-                    max_position_embeddings=512,
-                    dim=768,
-                    n_layers=6,
-                    n_heads=12,
-                    hidden_dim=3072,
-                    dropout=0.1,
-                    attention_dropout=0.1,
-                    activation='gelu',
-                    num_labels=2,
-                    id2label={0: "NON_DRUG", 1: "DRUG"},
-                    label2id={"NON_DRUG": 0, "DRUG": 1}
-                )
-
-                tokenizer = DistilBertTokenizerFast.from_pretrained(
+            else:
+                # Fallback to a working pre-trained model
+                logger.warning("Local model not found, using fallback...")
+                tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
+                model = AutoModelForSequenceClassification.from_pretrained(
                     'distilbert-base-uncased',
-                    do_lower_case=True
+                    num_labels=2
                 )
-                model = DistilBertForSequenceClassification(config)
-
-                # Try to load weights
-                weights_path = os.path.join(model_name, "pytorch_model.bin")
-                if os.path.exists(weights_path):
-                    state_dict = torch.load(weights_path, map_location=device)
-                    model.load_state_dict(state_dict, strict=False)
-                    logger.info("✅ Loaded custom weights")
-
-                model.to(device)
-                model.eval()
-                logger.info("✅ Successfully loaded with custom config")
-                model_loaded = True
-                return True
-            except Exception as e:
-                logger.error(f"Strategy 3 failed: {e}")
-
-        # Strategy 4: Use pre-trained DistilBERT as fallback
-        try:
-            logger.warning("Strategy 4: Falling back to pre-trained DistilBERT...")
-            tokenizer = DistilBertTokenizerFast.from_pretrained(
-                'distilbert-base-uncased',
-                do_lower_case=True
-            )
-            model = DistilBertForSequenceClassification.from_pretrained(
-                'distilbert-base-uncased',
-                num_labels=2
-            )
+
             model.to(device)
             model.eval()
-            logger.warning("⚠️ Using pre-trained DistilBERT (not fine-tuned for drug detection)")
             model_loaded = True
+            logger.info("Model loaded successfully")
             return True
+
         except Exception as e:
-            logger.error(f"All strategies failed: {e}")
+            logger.error(f"Model loading failed: {e}")
            return False
-
 # =======================
 # Optimized prediction function with enhanced accuracy
 # =======================
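
For readability, here is roughly how the function reads after this commit, assembled from the context and added lines above. This is a reconstruction, not part of the commit itself: the globals (model, tokenizer, model_loaded, model_lock, device, logger) and the os/transformers imports are assumed to be defined earlier in src/predict.py, and file lines 275-276 fall between the two hunks and are not shown in the diff.

def load_model_with_fallback(model_name: str) -> bool:
    """Simplified model loading for HF Spaces"""
    global model, tokenizer, model_loaded

    with model_lock:
        # ... (file lines 275-276 are not shown in this commit's diff) ...

        logger.info(f"Loading model: {model_name}")

        try:
            # Check if local model exists
            if os.path.exists(model_name):
                logger.info("Loading local model...")
                tokenizer = AutoTokenizer.from_pretrained(
                    model_name,
                    use_fast=True,
                    do_lower_case=True,
                    local_files_only=True  # Force local loading
                )
                model = AutoModelForSequenceClassification.from_pretrained(
                    model_name,
                    num_labels=2,
                    ignore_mismatched_sizes=True,
                    local_files_only=True  # Force local loading
                )
            else:
                # Fallback to a working pre-trained model
                logger.warning("Local model not found, using fallback...")
                tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
                model = AutoModelForSequenceClassification.from_pretrained(
                    'distilbert-base-uncased',
                    num_labels=2
                )

            model.to(device)
            model.eval()
            model_loaded = True
            logger.info("Model loaded successfully")
            return True

        except Exception as e:
            logger.error(f"Model loading failed: {e}")
            return False

The net effect of the commit is to replace the four-strategy retry chain (Auto classes, DistilBERT classes, custom DistilBertConfig plus manual weight loading, pre-trained fallback) with a single try/except containing one local-vs-fallback branch, and to pass local_files_only=True on the local path so the Space never attempts a network download for a model that is already on disk.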