"""
Secure Model Server - Protects model weights from extraction
Never expose:
- File paths to checkpoints
- Model architecture details
- Debug routes
"""
import os
import sys
import torch
import numpy as np
from pathlib import Path
from typing import Tuple, Optional
# Secure path resolution (not hardcoded)
def get_model_checkpoint_path():
    """Resolve the base SAM2 checkpoint path server-side.

    The resolved path is used internally only and is never exposed to
    clients.

    Returns:
        str: Filesystem path to the SAM2 checkpoint file.

    Raises:
        FileNotFoundError: If the checkpoint does not exist on disk.
    """
    base_dir = Path(__file__).parent
    checkpoint = base_dir / "segment-anything-2" / "checkpoints" / "sam2.1_hiera_small.pt"
    if not checkpoint.exists():
        # Plain string (was a pointless f-string); the message deliberately
        # omits the path so it cannot leak to clients via error propagation.
        raise FileNotFoundError("Model checkpoint not found")
    return str(checkpoint)
def get_finetuned_weights_path():
    """Resolve the fine-tuned VREyeSAM weights path server-side.

    Checks for a local copy first; if absent and an ``HF_TOKEN``
    environment variable is set, attempts to download the weights from
    Hugging Face. The resolved path is never exposed to clients.

    Returns:
        str: Filesystem path to the fine-tuned weights file.

    Raises:
        FileNotFoundError: If the weights are not present locally and
            could not be downloaded.
    """
    base_dir = Path(__file__).parent
    checkpoint_dir = base_dir / "segment-anything-2" / "checkpoints"
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    weights = checkpoint_dir / "VREyeSAM_uncertainity_best.torch"

    # Fast path: weights already present locally (e.g. uploaded to the Space).
    if weights.exists():
        return str(weights)

    # Try to download from Hugging Face using HF_TOKEN.
    hf_token = os.getenv('HF_TOKEN', '')
    if hf_token:
        try:
            # Imported lazily so the server still works when
            # huggingface_hub is unavailable but weights were uploaded.
            from huggingface_hub import hf_hub_download
            print("Downloading VREyeSAM weights from Hugging Face...")
            checkpoint_path = hf_hub_download(
                repo_id='devnagaich/VREyeSAM',
                filename='VREyeSAM_uncertainity_best.torch',
                token=hf_token,
                cache_dir=str(checkpoint_dir)
            )
            # Plain string (was a pointless f-string).
            print("Successfully downloaded VREyeSAM weights")
            return checkpoint_path
        except Exception as e:
            # Best-effort download: log and fall through to the final check.
            print(f"Warning: Could not download VREyeSAM weights: {e}")

    # Re-check: the file may have appeared via a concurrent upload since
    # the first check, or a failed download may have left a usable copy.
    if weights.exists():
        return str(weights)

    # Last resort - raise. Plain string (was a pointless f-string); the
    # message omits paths so they cannot leak to clients.
    raise FileNotFoundError("VREyeSAM weights not found and could not download")
def get_model_config_path():
    """Return the SAM2 model config path (server-side only, never sent to clients)."""
    config_path = "configs/sam2.1/sam2.1_hiera_s.yaml"
    return config_path
class ProtectedModelServer:
    """
    Encapsulates model loading and inference.

    Only exposes an inference API, never raw weights or checkpoint paths.
    Implemented as a process-wide singleton so the model is loaded
    exactly once.
    """
    _instance = None  # Singleton pattern: the one shared instance
    _model = None  # underlying SAM2 model (class-level, shared)
    _predictor = None  # SAM2ImagePredictor wrapping _model

    def __new__(cls):
        # Singleton: only one instance ever
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """Initialize model (only once; re-init on the singleton is a no-op)."""
        if self._predictor is None:
            self._load_model()

    def _load_model(self):
        """Load model weights securely - never called from frontend.

        Raises:
            RuntimeError: If any step of initialization fails (original
                exception chained as the cause).
        """
        try:
            # Add segment-anything-2 to path (internally only)
            base_dir = Path(__file__).parent
            sam2_path = base_dir / "segment-anything-2"
            if not sam2_path.exists():
                raise FileNotFoundError(f"SAM2 installation not found at {sam2_path}")
            # Guard against duplicate entries if loading ever re-runs.
            if str(sam2_path) not in sys.path:
                sys.path.insert(0, str(sam2_path))
            try:
                from sam2.build_sam import build_sam2
                from sam2.sam2_image_predictor import SAM2ImagePredictor
            except ImportError as e:
                raise ImportError("SAM2 not properly installed. Check build logs.") from e
            # Get paths internally - NEVER sent to client
            model_cfg = get_model_config_path()
            sam2_checkpoint = get_model_checkpoint_path()
            # Pick device: GPU if available, otherwise CPU.
            device = "cuda" if torch.cuda.is_available() else "cpu"
            print(f"Loading model on device: {device}")
            # Load base SAM2 model
            print(f"Loading SAM2 from {sam2_checkpoint}")
            self._model = build_sam2(model_cfg, sam2_checkpoint, device=device)
            self._predictor = SAM2ImagePredictor(self._model)
            # Try to load fine-tuned weights if available
            try:
                fine_tuned_weights = get_finetuned_weights_path()
                print(f"Loading fine-tuned weights from {fine_tuned_weights}")
                state_dict = torch.load(fine_tuned_weights, map_location=device)
                self._predictor.model.load_state_dict(state_dict)
                print("Fine-tuned weights loaded successfully")
            except FileNotFoundError:
                print("Warning: Fine-tuned weights not found. Using base SAM2 model.")
                print("To use fine-tuned model, upload VREyeSAM_uncertainity_best.torch to Space Files")
            except Exception as e:
                # Best-effort: a corrupt/mismatched state dict must not
                # prevent serving the base model.
                print(f"Warning: Could not load fine-tuned weights: {e}")
                print("Continuing with base SAM2 model")
            # Model is now loaded - weights are NOT accessible to clients
            self._predictor.model.eval()
            print("Model loaded successfully")
            return True
        except Exception as e:
            print(f"Error loading model: {e}")
            import traceback
            traceback.print_exc()
            raise RuntimeError(f"Model initialization failed: {str(e)}") from e

    def predict(self, image: np.ndarray, num_samples: int = 30) -> Tuple[np.ndarray, np.ndarray]:
        """
        Perform iris segmentation.

        Args:
            image: Input image (H x W [x C] numpy array)
            num_samples: Number of random point prompts for inference
        Returns:
            binary_mask: Binary segmentation mask (uint8, prob > 0.2)
            prob_mask: Score-weighted probability map clipped to [0, 1]
        Raises:
            RuntimeError: If the model is not initialized or inference fails.
        """
        if self._predictor is None:
            raise RuntimeError("Model not initialized")
        try:
            # Generate random point prompts covering the whole image.
            # SAM point coords are (x, y), so x is bounded by width and y
            # by height. (Previously both axes used min(H, W), confining
            # prompts to a square corner region on non-square images.)
            height, width = image.shape[:2]
            xs = np.random.randint(0, width, (num_samples, 1))
            ys = np.random.randint(0, height, (num_samples, 1))
            input_points = np.stack([xs, ys], axis=-1)  # (num_samples, 1, 2)
            # Inference
            with torch.no_grad():
                self._predictor.set_image(image)
                masks, scores, _ = self._predictor.predict(
                    point_coords=input_points,
                    point_labels=np.ones([input_points.shape[0], 1])
                )
            # Convert to numpy
            np_masks = np.array(masks[:, 0]).astype(np.float32)
            np_scores = scores[:, 0]
            # Normalize scores so they form a convex combination of masks.
            score_sum = np.sum(np_scores)
            if score_sum > 0:
                normalized_scores = np_scores / score_sum
            else:
                # Degenerate case (all-zero scores): fall back to uniform weights.
                normalized_scores = np.ones_like(np_scores) / len(np_scores)
            # Generate probabilistic mask: score-weighted average of per-point masks.
            prob_mask = np.sum(np_masks * normalized_scores[:, None, None], axis=0)
            prob_mask = np.clip(prob_mask, 0, 1)
            # Threshold to get binary mask
            binary_mask = (prob_mask > 0.2).astype(np.uint8)
            return binary_mask, prob_mask
        except Exception as e:
            # Generic message (was a pointless f-string): internals stay
            # hidden from clients; the cause is chained for server logs.
            raise RuntimeError("Inference failed") from e
def get_predictor() -> ProtectedModelServer:
    """Return the shared singleton model server instance."""
    server = ProtectedModelServer()
    return server
|