feat(train): add bc5cdr/ncbi_disease aliases; robust dataset loader and retry

#4
by SHA888 - opened
Files changed (1)
  1. app.py +36 -3
app.py CHANGED
@@ -9,7 +9,7 @@ from huggingface_hub import HfApi, create_repo
 
 
 DEFAULT_BASE_MODEL = "dmis-lab/biobert-base-cased-v1.2"
-DEFAULT_DATASET = "conll2003" # fallback; medical sets may require custom preprocessing
+DEFAULT_DATASET = "wikiann:en" # robust default; medical sets may require custom preprocessing
 TARGET_REPO = os.getenv("MEDVLLM_TARGET_REPO", "Junaidi-AI/med-vllm")
 
 
@@ -29,6 +29,8 @@ def _train_ner_lora(
     Minimal LoRA token-classification trainer.
     Uses conll2003 by default to be robust in Spaces. Extend to medical datasets later.
     """
+    # Avoid importing any local dataset scripts even if present in working dir
+    os.environ.setdefault("HF_DATASETS_DISABLE_LOCAL_IMPORTS", "1")
     from datasets import load_dataset
     from transformers import (
         AutoTokenizer,
@@ -49,8 +51,39 @@ def _train_ner_lora(
 
     set_seed(42)
 
-    log(f"Loading dataset: {dataset_name}")
-    ds = load_dataset(dataset_name)
+    ds_spec = (dataset_name or "").strip()
+    log(f"Loading dataset: {ds_spec}")
+    # Support optional config via 'name:config' (e.g., 'wikiann:en')
+    try:
+        # Medical aliases -> BigBio NER configs
+        alias_map = {
+            "bc5cdr": ("bigbio/bc5cdr", "bigbio_ner"),
+            "ncbi_disease": ("bigbio/ncbi_disease", "bigbio_ner"),
+        }
+        lower_spec = ds_spec.lower()
+        if lower_spec in alias_map:
+            ds_name, ds_config = alias_map[lower_spec]
+            log(f"Using alias mapping: {ds_spec} -> {ds_name}:{ds_config}")
+            ds = load_dataset(ds_name, ds_config)
+        elif ":" in ds_spec:
+            ds_name, ds_config = [s.strip() for s in ds_spec.split(":", 1)]
+            ds = load_dataset(ds_name, ds_config)
+        else:
+            ds = load_dataset(ds_spec)
+    except Exception as e:
+        # Fallback: if it looks like 'name:config' but was treated as a local path, try explicit two-arg call
+        err_msg = str(e)
+        log(f"Dataset load failed: {err_msg}")
+        if ":" in ds_spec:
+            try:
+                ds_name, ds_config = [s.strip() for s in ds_spec.split(":", 1)]
+                log(f"Retrying with split name/config: {ds_name}, {ds_config}")
+                ds = load_dataset(ds_name, ds_config)
+            except Exception as e2:
+                log(f"Retry failed: {e2}")
+                raise
+        else:
+            raise
 
     if "train" not in ds:
         raise RuntimeError("Dataset must have a train split")
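
For reviewers, the resolution order the new loader implements is: known medical alias first, then an explicit 'name:config' pair, then a bare dataset name. Below is a minimal standalone sketch of that logic; the helper name resolve_dataset_spec is hypothetical (the diff inlines the same steps in _train_ner_lora), so treat it as an illustration rather than the shipped code.

from typing import Optional, Tuple

# Same alias table as the diff: friendly names -> (BigBio dataset, config).
ALIAS_MAP = {
    "bc5cdr": ("bigbio/bc5cdr", "bigbio_ner"),
    "ncbi_disease": ("bigbio/ncbi_disease", "bigbio_ner"),
}

def resolve_dataset_spec(spec: str) -> Tuple[str, Optional[str]]:
    # Hypothetical helper mirroring the inline logic in _train_ner_lora.
    spec = (spec or "").strip()
    lower = spec.lower()
    if lower in ALIAS_MAP:
        # 1. Alias: map to the BigBio dataset and its NER config.
        return ALIAS_MAP[lower]
    if ":" in spec:
        # 2. 'name:config': split once on the first colon.
        name, config = (s.strip() for s in spec.split(":", 1))
        return name, config
    # 3. Bare name: let load_dataset pick the default config.
    return spec, None

# One example per resolution path:
assert resolve_dataset_spec("BC5CDR") == ("bigbio/bc5cdr", "bigbio_ner")
assert resolve_dataset_spec("wikiann:en") == ("wikiann", "en")
assert resolve_dataset_spec("conll2003") == ("conll2003", None)

Factoring the resolution out this way would also let the retry branch call load_dataset(name, config) from a single place instead of re-splitting the spec inside the except handler.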