Spaces:
Runtime error
Runtime error
FEAT: Refactor to download model from HF Hub at runtime
Browse files- .gitignore +5 -0
- src/cnnClassifier/pipeline/prediction.py +15 -16
.gitignore
CHANGED
|
@@ -206,3 +206,8 @@ marimo/_static/
|
|
| 206 |
marimo/_lsp/
|
| 207 |
__marimo__/
|
| 208 |
aws-key.pem
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
marimo/_lsp/
|
| 207 |
__marimo__/
|
| 208 |
aws-key.pem
|
| 209 |
+
model/
|
| 210 |
+
artifacts/
|
| 211 |
+
*.pt
|
| 212 |
+
*.bin
|
| 213 |
+
*.safetensors
|
src/cnnClassifier/pipeline/prediction.py
CHANGED
|
@@ -21,29 +21,28 @@ except ImportError:
|
|
| 21 |
from src.cnnClassifier.utils.common import read_yaml
|
| 22 |
|
| 23 |
class PredictionPipeline:
|
| 24 |
-
def __init__(self,
|
| 25 |
-
self.device = "cpu"
|
| 26 |
-
self.
|
| 27 |
-
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
-
self.
|
| 31 |
-
|
| 32 |
-
'gender_id2label': {'0': 'Male', '1': 'Female'}
|
| 33 |
-
}
|
| 34 |
|
| 35 |
-
|
| 36 |
self.processor = AutoImageProcessor.from_pretrained(self.base_model_name)
|
| 37 |
self.transforms = Compose([Resize((self.params.IMAGE_SIZE, self.params.IMAGE_SIZE)), ToTensor(), Normalize(mean=self.processor.image_mean, std=self.processor.image_std)])
|
| 38 |
self.model = self._load_model()
|
| 39 |
|
| 40 |
-
# --- THE FIX: LOAD BOTH DETECTORS ---
|
| 41 |
-
# High-quality detector for offline tasks
|
| 42 |
-
self.hq_face_detector = MTCNN()
|
| 43 |
-
# Lightweight detector for live feed
|
| 44 |
haar_cascade_path = cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
|
| 45 |
-
self.
|
| 46 |
-
# --- END FIX ---
|
| 47 |
|
| 48 |
print(f"--- Pipeline Initialized Successfully on device: {self.device} ---")
|
| 49 |
|
|
|
|
| 21 |
from src.cnnClassifier.utils.common import read_yaml
|
| 22 |
|
| 23 |
class PredictionPipeline:
|
| 24 |
+
def __init__(self, repo_id: str = "ALYYAN/Facial-Age-Det"):
|
| 25 |
+
self.device = "cpu"
|
| 26 |
+
self.repo_id = repo_id
|
| 27 |
+
|
| 28 |
+
print("--- Initializing Prediction Pipeline by downloading artifacts from Hub ---")
|
| 29 |
+
|
| 30 |
+
# --- THE FIX: Download all artifacts from your HF Model Repo ---
|
| 31 |
+
self.model_path = hf_hub_download(repo_id=self.repo_id, filename="checkpoint-26873/model.safetensors")
|
| 32 |
+
self.params_path = hf_hub_download(repo_id=self.repo_id, filename="params.yaml")
|
| 33 |
+
self.data_csv_path = hf_hub_download(repo_id=self.repo_id, filename="fairface_cleaned.csv")
|
| 34 |
+
# --- END FIX ---
|
| 35 |
|
| 36 |
+
self.base_model_name = "google/efficientnet-b2"
|
| 37 |
+
self.params = read_yaml(Path(self.params_path))
|
|
|
|
|
|
|
| 38 |
|
| 39 |
+
self.label_maps = self._load_label_maps()
|
| 40 |
self.processor = AutoImageProcessor.from_pretrained(self.base_model_name)
|
| 41 |
self.transforms = Compose([Resize((self.params.IMAGE_SIZE, self.params.IMAGE_SIZE)), ToTensor(), Normalize(mean=self.processor.image_mean, std=self.processor.image_std)])
|
| 42 |
self.model = self._load_model()
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
haar_cascade_path = cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
|
| 45 |
+
self.face_detector = cv2.CascadeClassifier(haar_cascade_path)
|
|
|
|
| 46 |
|
| 47 |
print(f"--- Pipeline Initialized Successfully on device: {self.device} ---")
|
| 48 |
|