Dev Nagaich commited on
Commit
6f575dc
·
1 Parent(s): 2751713

Fix: Simplify Dockerfile and add runtime weight download to model_server

Browse files
Files changed (2) hide show
  1. Dockerfile +2 -24
  2. model_server.py +39 -7
Dockerfile CHANGED
@@ -38,30 +38,8 @@ COPY app.py model_server.py ./
38
  RUN mkdir -p .streamlit
39
  COPY .streamlit/config.toml .streamlit/
40
 
41
- # Download VREyeSAM fine-tuned weights from Hugging Face
42
- # This assumes your weights are available on HF Model Hub
43
- RUN if [ ! -z "$HF_TOKEN" ]; then \
44
- python -c "
45
- from huggingface_hub import hf_hub_download
46
- import os
47
-
48
- hf_token = os.getenv('HF_TOKEN')
49
- if hf_token:
50
- # Download VREyeSAM weights from your HF repo
51
- # Replace 'your-username/vreyesam' with your actual repo
52
- try:
53
- checkpoint = hf_hub_download(
54
- repo_id='devnagaich/VREyeSAM',
55
- filename='VREyeSAM_uncertainity_best.torch',
56
- token=hf_token,
57
- cache_dir='segment-anything-2/checkpoints'
58
- )
59
- print(f'Downloaded VREyeSAM weights to: {checkpoint}')
60
- except Exception as e:
61
- print(f'Warning: Could not download VREyeSAM weights: {e}')
62
- print('App will attempt to load from cache at runtime')
63
- " || true; \
64
- fi
65
 
66
  # Expose Streamlit port
67
  EXPOSE 7860
 
38
  RUN mkdir -p .streamlit
39
  COPY .streamlit/config.toml .streamlit/
40
 
41
+ # Note: VREyeSAM fine-tuned weights will be downloaded at runtime by model_server.py
42
+ # using the HF_TOKEN from HF Spaces Secrets
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  # Expose Streamlit port
45
  EXPOSE 7860
model_server.py CHANGED
@@ -17,22 +17,54 @@ from typing import Tuple, Optional
17
  def get_model_checkpoint_path():
18
  """Get checkpoint path secretly, never expose to client"""
19
  base_dir = Path(__file__).parent
20
- checkpoint = base_dir / "segment-anything-2" / "checkpoints" / "sam2_hiera_small.pt"
21
  if not checkpoint.exists():
22
  raise FileNotFoundError(f"Model checkpoint not found")
23
  return str(checkpoint)
24
 
25
  def get_finetuned_weights_path():
26
- """Get fine-tuned weights path secretly, never expose to client"""
 
 
 
 
27
  base_dir = Path(__file__).parent
28
- weights = base_dir / "segment-anything-2" / "checkpoints" / "VREyeSAM_uncertainity_best.torch"
29
- if not weights.exists():
30
- raise FileNotFoundError(f"Fine-tuned weights not found")
31
- return str(weights)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  def get_model_config_path():
34
  """Get model config path secretly, never expose to client"""
35
- return "configs/sam2/sam2_hiera_s.yaml"
36
 
37
 
38
  class ProtectedModelServer:
 
17
  def get_model_checkpoint_path():
18
  """Get checkpoint path secretly, never expose to client"""
19
  base_dir = Path(__file__).parent
20
+ checkpoint = base_dir / "segment-anything-2" / "checkpoints" / "sam2.1_hiera_small.pt"
21
  if not checkpoint.exists():
22
  raise FileNotFoundError(f"Model checkpoint not found")
23
  return str(checkpoint)
24
 
25
  def get_finetuned_weights_path():
26
+ """Get fine-tuned weights path secretly, never expose to client
27
+
28
+ Attempts to download from Hugging Face if local copy doesn't exist
29
+ and HF_TOKEN is available.
30
+ """
31
  base_dir = Path(__file__).parent
32
+ checkpoint_dir = base_dir / "segment-anything-2" / "checkpoints"
33
+ checkpoint_dir.mkdir(parents=True, exist_ok=True)
34
+ weights = checkpoint_dir / "VREyeSAM_uncertainity_best.torch"
35
+
36
+ # If weights already exist locally, return path
37
+ if weights.exists():
38
+ return str(weights)
39
+
40
+ # Try to download from Hugging Face using HF_TOKEN
41
+ hf_token = os.getenv('HF_TOKEN', '')
42
+ if hf_token:
43
+ try:
44
+ from huggingface_hub import hf_hub_download
45
+ print("Downloading VREyeSAM weights from Hugging Face...")
46
+
47
+ checkpoint_path = hf_hub_download(
48
+ repo_id='devnagaich/VREyeSAM',
49
+ filename='VREyeSAM_uncertainity_best.torch',
50
+ token=hf_token,
51
+ cache_dir=str(checkpoint_dir)
52
+ )
53
+ print(f"Successfully downloaded VREyeSAM weights")
54
+ return checkpoint_path
55
+ except Exception as e:
56
+ print(f"Warning: Could not download VREyeSAM weights: {e}")
57
+
58
+ # If download fails or no token, return path anyway (may exist from upload)
59
+ if weights.exists():
60
+ return str(weights)
61
+
62
+ # Last resort - raise error
63
+ raise FileNotFoundError(f"VREyeSAM weights not found and could not download")
64
 
65
  def get_model_config_path():
66
  """Get model config path secretly, never expose to client"""
67
+ return "configs/sam2.1/sam2.1_hiera_s.yaml"
68
 
69
 
70
  class ProtectedModelServer: