Maximilian Schuh commited on
Commit
e174389
·
1 Parent(s): e19f3e1

Added local model weights

Browse files
Files changed (3) hide show
  1. app.py +19 -2
  2. weights/bt_model.tar.xz +3 -0
  3. weights/lgbm_model.tar.xz +3 -0
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  import tempfile
3
  from pathlib import Path
4
  from typing import List, Tuple
@@ -16,6 +17,7 @@ HF_BT_REPO = os.environ.get("HF_BT_REPO", "mschuh/TwinBooster-bt")
16
 
17
  # Store models where TwinBooster looks for them by default (~/.cache/twinbooster)
18
  MODEL_DIR = Path.home() / ".cache" / "twinbooster"
 
19
 
20
 
21
  def ensure_models() -> str:
@@ -40,11 +42,26 @@ def ensure_models() -> str:
40
  except Exception:
41
  return False
42
 
 
 
 
 
 
 
 
 
 
43
  if grab(HF_TEXT_REPO, "PubChemDeBERTa-augmented"):
44
  downloaded.append(HF_TEXT_REPO)
45
- if HF_LGBM_REPO and grab(HF_LGBM_REPO, "lgbm_model"):
 
 
 
46
  downloaded.append(HF_LGBM_REPO)
47
- if HF_BT_REPO and grab(HF_BT_REPO, "bt_model"):
 
 
 
48
  downloaded.append(HF_BT_REPO)
49
 
50
  # Ensure any missing pieces are resolved via package helper (will skip if already present)
 
1
  import os
2
+ import tarfile
3
  import tempfile
4
  from pathlib import Path
5
  from typing import List, Tuple
 
17
 
18
  # Store models where TwinBooster looks for them by default (~/.cache/twinbooster)
19
  MODEL_DIR = Path.home() / ".cache" / "twinbooster"
20
+ WEIGHTS_SRC = Path(__file__).parent / "weights"
21
 
22
 
23
  def ensure_models() -> str:
 
42
  except Exception:
43
  return False
44
 
45
+ def extract_local(archive: Path, subdir: str) -> bool:
46
+ target = MODEL_DIR / subdir
47
+ if target.exists() or not archive.exists():
48
+ return False
49
+ target.mkdir(parents=True, exist_ok=True)
50
+ with tarfile.open(archive, "r:*") as tf:
51
+ tf.extractall(target)
52
+ return True
53
+
54
  if grab(HF_TEXT_REPO, "PubChemDeBERTa-augmented"):
55
  downloaded.append(HF_TEXT_REPO)
56
+
57
+ if extract_local(WEIGHTS_SRC / "lgbm_model.tar.xz", "lgbm_model"):
58
+ downloaded.append("local lgbm_model.tar.xz")
59
+ elif HF_LGBM_REPO and grab(HF_LGBM_REPO, "lgbm_model"):
60
  downloaded.append(HF_LGBM_REPO)
61
+
62
+ if extract_local(WEIGHTS_SRC / "bt_model.tar.xz", "bt_model"):
63
+ downloaded.append("local bt_model.tar.xz")
64
+ elif HF_BT_REPO and grab(HF_BT_REPO, "bt_model"):
65
  downloaded.append(HF_BT_REPO)
66
 
67
  # Ensure any missing pieces are resolved via package helper (will skip if already present)
weights/bt_model.tar.xz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:70607f078e00861c92e36c12a873a39e7e54842285653209ed81ac02337662b0
3
+ size 466955972
weights/lgbm_model.tar.xz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e599e171a9cba778b9d858e283ee874addf9513ee97765b2890c729ac59a6bd3
3
+ size 23041344