Spaces:
Sleeping
Sleeping
import logging
import os
import tempfile
from pathlib import Path
from typing import Optional

from huggingface_hub import snapshot_download
def resolve_model_dir(require_weights: bool = True) -> str:
    """
    Determine which vit-finetuned directory to use for inference/training.

    Priority order:
      1. Explicit MODEL_DIR environment variable (``~`` is expanded). This
         path is returned as-is and is NOT checked for model weights.
      2. Shared copy in the sibling plant-disease-rag-assistant repo, used
         only if it already contains model.safetensors.
      3. Local models/vit-finetuned directory in this repository (the parent
         of this module's directory).
      4. Only when require_weights is True: a snapshot downloaded from the
         Hugging Face Hub by _download_remote_model().

    Args:
        require_weights: When True, step 3 is used only if it already holds
            model.safetensors, and step 4 is tried otherwise. When False, the
            local directory is created if missing so that training code can
            write into it, and the Hub is never contacted.

    Returns:
        Path of the selected model directory.

    Raises:
        FileNotFoundError: require_weights is True and neither the shared,
            local, nor remote location provides model.safetensors.
    """
    override = os.getenv("MODEL_DIR")
    if override:
        # An explicit override is trusted verbatim; expand "~" so values from
        # .env files or UI settings (which the shell never sees) still work.
        return os.path.abspath(os.path.expanduser(override))

    repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))

    shared_dir = _resolve_shared_repo_path(repo_root)
    if shared_dir:
        return shared_dir

    local_dir = os.path.join(repo_root, "models", "vit-finetuned")
    if not require_weights:
        # Training mode: guarantee a writable destination exists.
        os.makedirs(local_dir, exist_ok=True)
        return local_dir

    if os.path.exists(os.path.join(local_dir, "model.safetensors")):
        return local_dir

    # Last resort for inference: try to fetch the weights from the Hub.
    remote_dir = _download_remote_model()
    if remote_dir:
        return remote_dir

    raise FileNotFoundError(
        "model.safetensors not found in shared, local, or remote "
        "(Hugging Face Hub) model locations. "
        "Set MODEL_DIR to point at a valid model location."
    )
| def _resolve_shared_repo_path(repo_root: str) -> Optional[str]: | |
| shared_dir = os.path.abspath( | |
| os.path.join( | |
| repo_root, | |
| "..", | |
| "plant-disease-rag-assistant", | |
| "models", | |
| "vit-finetuned", | |
| ) | |
| ) | |
| shared_model = os.path.join(shared_dir, "model.safetensors") | |
| if os.path.exists(shared_model): | |
| return shared_dir | |
| return None | |
| def _download_remote_model() -> Optional[str]: | |
| repo_id = os.getenv("MODEL_REPO_ID", "mcherif/Plant-Disease-RAG-Assistant") | |
| repo_type = os.getenv("MODEL_REPO_TYPE", "space") | |
| subdir = os.getenv("MODEL_REPO_SUBDIR", "models/vit-finetuned") | |
| if not repo_id: | |
| return None | |
| cache_dir = os.getenv("MODEL_REPO_CACHE_DIR") | |
| if cache_dir: | |
| cache_path = Path(cache_dir) | |
| else: | |
| cache_path = Path(os.getenv("HF_HOME", tempfile.gettempdir())) | |
| cache_path = cache_path / "model_snapshots" | |
| cache_path.mkdir(parents=True, exist_ok=True) | |
| try: | |
| snapshot_path = snapshot_download( | |
| repo_id=repo_id, | |
| repo_type=repo_type, | |
| allow_patterns=[f"{subdir}/*"], | |
| local_dir=str(cache_path), | |
| local_dir_use_symlinks=False, | |
| ) | |
| except Exception: | |
| return None | |
| candidate = Path(snapshot_path) / subdir | |
| model_file = candidate / "model.safetensors" | |
| return str(candidate) if model_file.exists() else None | |