yinuozhang committed on
Commit
40e900b
·
verified ·
1 Parent(s): 6a51705

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -46
app.py CHANGED
@@ -1,26 +1,14 @@
1
- # ---- BOOTSTRAP: keep storage under control on Spaces ----
2
- import os, shutil, subprocess
3
- from huggingface_hub import scan_cache_dir, snapshot_download
4
 
5
- # Put caches in /data and make sure dirs exist
6
  os.makedirs("/data/.cache/huggingface/hub", exist_ok=True)
7
  os.makedirs("/data/snapshots", exist_ok=True)
8
-
9
  os.environ.setdefault("XDG_CACHE_HOME", "/data/.cache")
10
  os.environ.setdefault("HF_HOME", "/data/.cache/huggingface")
11
  os.environ.setdefault("HF_HUB_CACHE", "/data/.cache/huggingface/hub")
12
- # Avoid TRANSFORMERS_CACHE deprecation; HF_HOME is enough.
13
- # os.environ.setdefault("TRANSFORMERS_CACHE", "/data/.cache/huggingface/transformers")
14
-
15
- # Prune old HF cache revisions (safe if empty; now the dir exists)
16
- try:
17
- cache = scan_cache_dir(os.environ["HF_HUB_CACHE"])
18
- if cache.revisions:
19
- cache.delete_revisions([rev for rev in cache.revisions])
20
- except Exception as e:
21
- print(f"[cache prune] skipped: {e}")
22
 
23
- # Light pip cache cleanup
24
  try:
25
  subprocess.run(["pip", "cache", "purge"], check=False)
26
  except Exception:
@@ -30,74 +18,94 @@ except Exception:
30
  import gradio as gr
31
  import sys
32
  import pandas as pd
33
- from transformers import AutoTokenizer, AutoModel, AutoConfig
 
34
 
35
- # Optional: pin commits via Space Variables
36
  MODEL_ID = "ChatterjeeLab/MetaLATTE"
37
  TOKENIZER_ID = "facebook/esm2_t33_650M_UR50D"
38
- MODEL_REV = os.getenv("MODEL_REV", "") # e.g. "a1b2c3d"
39
- TOKENIZER_REV = os.getenv("TOKENIZER_REV", "") # e.g. "9f8e7d6"
40
 
41
  def snapshot_to(local_name, repo_id, revision, allow_patterns):
42
- """Download only needed files into a concrete folder under /data/snapshots."""
43
  local_dir = f"/data/snapshots/{local_name}"
44
  os.makedirs(local_dir, exist_ok=True)
45
- # IMPORTANT: no ignore_regex; use ignore_patterns if needed
46
  return snapshot_download(
47
  repo_id=repo_id,
48
  revision=revision if revision else None,
49
  allow_patterns=allow_patterns,
50
- local_dir=local_dir,
51
- local_dir_use_symlinks=False, # copy files into local_dir; easier to manage size
52
  )
53
 
54
- # Tokenizer (small set of files)
55
  esm_local = snapshot_to(
56
  "esm2_tokenizer",
57
  TOKENIZER_ID,
58
  TOKENIZER_REV,
59
  allow_patterns=[
60
  "tokenizer.json","tokenizer_config.json","vocab.*","merges.*",
61
- "special_tokens_map.json","*.model","tokenizer*.txt","spiece.*","*.tiktoken"
 
62
  ],
63
  )
64
 
65
- # MetaLATTE model (weights + config only)
66
  metalatte_local = snapshot_to(
67
  "metalatte_model",
68
  MODEL_ID,
69
  MODEL_REV,
70
- allow_patterns=["*.json","*.safetensors","*.bin","*.model","*.txt"],
71
  )
72
 
73
- # Your local package
74
  metalatte_path = '.'
75
  sys.path.insert(0, metalatte_path)
76
-
77
  from configuration import MetaLATTEConfig
78
  from modeling_metalatte import MultitaskProteinModel
79
- AutoConfig.register("metalatte", MetaLATTEConfig)
80
- AutoModel.register(MetaLATTEConfig, MultitaskProteinModel)
81
 
82
- # Load from the downloaded dirs (no network, no extra cache growth)
83
- tokenizer = AutoTokenizer.from_pretrained(esm_local, local_files_only=True)
84
  config = AutoConfig.from_pretrained(metalatte_local, local_files_only=True)
85
- model = AutoModel.from_pretrained(metalatte_local, config=config, local_files_only=True)
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
 
88
  def predict(sequence):
89
  inputs = tokenizer(sequence, return_tensors="pt")
90
  raw_probs, predictions = model.predict(**inputs)
91
-
92
  id2label = config.id2label
93
- results = {}
94
- for i, pred in enumerate(predictions[0]):
95
- metal = id2label[i]
96
- probability = raw_probs[0][i].item()
97
- results[metal] = '✓' if pred == 1 else ''
98
-
99
- df = pd.DataFrame([results])
100
- return df
101
 
102
  iface = gr.Interface(
103
  fn=predict,
@@ -106,5 +114,4 @@ iface = gr.Interface(
106
  title="MetaLATTE: Metal Binding Prediction",
107
  description="Enter a protein sequence to predict its metal binding properties."
108
  )
109
-
110
- iface.launch()
 
1
+ # ---- BOOTSTRAP: stable cache to /data, minimal downloads ----
2
+ import os, subprocess
3
+ from huggingface_hub import snapshot_download
4
 
 
5
  os.makedirs("/data/.cache/huggingface/hub", exist_ok=True)
6
  os.makedirs("/data/snapshots", exist_ok=True)
 
7
  os.environ.setdefault("XDG_CACHE_HOME", "/data/.cache")
8
  os.environ.setdefault("HF_HOME", "/data/.cache/huggingface")
9
  os.environ.setdefault("HF_HUB_CACHE", "/data/.cache/huggingface/hub")
 
 
 
 
 
 
 
 
 
 
10
 
11
+ # Optional: keep pip cache small
12
  try:
13
  subprocess.run(["pip", "cache", "purge"], check=False)
14
  except Exception:
 
18
  import gradio as gr
19
  import sys
20
  import pandas as pd
21
+ import torch
22
+ from transformers import AutoTokenizer, AutoConfig
23
 
24
+ # Pin via Space → Settings → Variables if you want (helps avoid repeated downloads)
25
  MODEL_ID = "ChatterjeeLab/MetaLATTE"
26
  TOKENIZER_ID = "facebook/esm2_t33_650M_UR50D"
27
+ MODEL_REV = os.getenv("MODEL_REV", "ad1716045c768b30ce87eb6b3963d58578fa5401") # from your screenshot
28
+ TOKENIZER_REV = os.getenv("TOKENIZER_REV", "")
29
 
30
  def snapshot_to(local_name, repo_id, revision, allow_patterns):
 
31
  local_dir = f"/data/snapshots/{local_name}"
32
  os.makedirs(local_dir, exist_ok=True)
 
33
  return snapshot_download(
34
  repo_id=repo_id,
35
  revision=revision if revision else None,
36
  allow_patterns=allow_patterns,
37
+ local_dir=local_dir, # new hub ignores symlink flag; this is enough
 
38
  )
39
 
40
+ # Download tokenizer files (small)
41
  esm_local = snapshot_to(
42
  "esm2_tokenizer",
43
  TOKENIZER_ID,
44
  TOKENIZER_REV,
45
  allow_patterns=[
46
  "tokenizer.json","tokenizer_config.json","vocab.*","merges.*",
47
+ "special_tokens_map.json","*.model","tokenizer*.txt","spiece.*","*.tiktoken",
48
+ "config.json" # some tokenizers use it
49
  ],
50
  )
51
 
52
+ # Download MetaLATTE weights + config ONLY (skip stage1 blob)
53
  metalatte_local = snapshot_to(
54
  "metalatte_model",
55
  MODEL_ID,
56
  MODEL_REV,
57
+ allow_patterns=["config.json", "pytorch_model.bin"],
58
  )
59
 
60
+ # Your local custom code
61
  metalatte_path = '.'
62
  sys.path.insert(0, metalatte_path)
 
63
  from configuration import MetaLATTEConfig
64
  from modeling_metalatte import MultitaskProteinModel
 
 
65
 
66
+ # Load config + instantiate model (no network)
 
67
  config = AutoConfig.from_pretrained(metalatte_local, local_files_only=True)
 
68
 
69
+ # Find the weight file locally
70
+ weight_candidates = [
71
+ "pytorch_model.bin",
72
+ "model/pytorch_model.bin",
73
+ "model.safetensors",
74
+ "model/model.safetensors",
75
+ "stage1_model.bin",
76
+ "model/stage1_model.bin",
77
+ ]
78
+ weight_path = None
79
+ for c in weight_candidates:
80
+ p = os.path.join(metalatte_local, c)
81
+ if os.path.exists(p):
82
+ weight_path = p
83
+ break
84
+ if weight_path is None:
85
+ raise FileNotFoundError(f"No weights found in {metalatte_local}. Looked for: {weight_candidates}")
86
+
87
+ # Build model and load the local state dict
88
+ model = MultitaskProteinModel(config)
89
+ if weight_path.endswith(".safetensors"):
90
+ from safetensors.torch import load_file
91
+ state_dict = load_file(weight_path, device="cpu")
92
+ else:
93
+ state_dict = torch.load(weight_path, map_location="cpu")
94
+ missing, unexpected = model.load_state_dict(state_dict, strict=False)
95
+ if missing or unexpected:
96
+ print(f"[load_state_dict] missing={len(missing)} unexpected={len(unexpected)}")
97
+ model.eval()
98
+
99
+ # Tokenizer
100
+ tokenizer = AutoTokenizer.from_pretrained(esm_local, local_files_only=True)
101
 
102
+ @torch.inference_mode()
103
  def predict(sequence):
104
  inputs = tokenizer(sequence, return_tensors="pt")
105
  raw_probs, predictions = model.predict(**inputs)
 
106
  id2label = config.id2label
107
+ row = {id2label[i]: ('✓' if int(pred) == 1 else '') for i, pred in enumerate(predictions[0])}
108
+ return pd.DataFrame([row])
 
 
 
 
 
 
109
 
110
  iface = gr.Interface(
111
  fn=predict,
 
114
  title="MetaLATTE: Metal Binding Prediction",
115
  description="Enter a protein sequence to predict its metal binding properties."
116
  )
117
+ iface.launch()