"""Locate the fine-tuned ViT model directory used for inference and training."""

import logging
import os
import tempfile
from pathlib import Path
from typing import Optional

try:
    from huggingface_hub import snapshot_download
except ImportError:
    # The hub client is only needed for the remote-download fallback; keep the
    # module importable so MODEL_DIR / shared / local resolution still works.
    snapshot_download = None

logger = logging.getLogger(__name__)

# File whose presence marks a directory as containing usable model weights.
_WEIGHTS_FILENAME = "model.safetensors"


def resolve_model_dir(require_weights: bool = True) -> str:
    """
    Determine which vit-finetuned directory to use for inference/training.

    Priority order:
    1. Explicit MODEL_DIR environment variable (returned as an absolute path
       without checking its contents).
    2. Shared copy in the sibling plant-disease-rag-assistant repo, if it
       contains model.safetensors.
    3. Local models/vit-finetuned directory in this repository.
    4. Only when require_weights is True and the local directory has no
       model.safetensors: a snapshot downloaded by _download_remote_model().

    Args:
        require_weights: When True, the local directory is only accepted if it
            already contains model.safetensors; otherwise the remote download
            is attempted. When False, the local directory is created if
            missing so that training code can write into it.

    Returns:
        Path of the selected model directory.

    Raises:
        FileNotFoundError: require_weights is True and no directory with
            model.safetensors could be found or downloaded.
    """
    override = os.getenv("MODEL_DIR")
    if override:
        return os.path.abspath(override)

    # The repository root is the parent of the directory holding this file.
    repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))

    shared_dir = _resolve_shared_repo_path(repo_root)
    if shared_dir:
        return shared_dir

    local_dir = os.path.join(repo_root, "models", "vit-finetuned")
    if require_weights:
        if os.path.exists(os.path.join(local_dir, _WEIGHTS_FILENAME)):
            return local_dir
        remote_dir = _download_remote_model()
        if remote_dir:
            return remote_dir
        raise FileNotFoundError(
            "model.safetensors not found in shared or local model directories. "
            "Set MODEL_DIR to point at a valid model location."
        )

    os.makedirs(local_dir, exist_ok=True)
    return local_dir


def _resolve_shared_repo_path(repo_root: str) -> Optional[str]:
    """Return the sibling repo's vit-finetuned directory if it holds weights.

    The directory is looked up at
    <repo_root>/../plant-disease-rag-assistant/models/vit-finetuned and is
    only returned when model.safetensors exists inside it; otherwise None.
    """
    shared_dir = os.path.abspath(
        os.path.join(
            repo_root,
            "..",
            "plant-disease-rag-assistant",
            "models",
            "vit-finetuned",
        )
    )
    if os.path.exists(os.path.join(shared_dir, _WEIGHTS_FILENAME)):
        return shared_dir
    return None


def _download_remote_model() -> Optional[str]:
    """Best-effort download of the model directory from the Hugging Face Hub.

    Configuration is read from environment variables:
    - MODEL_REPO_ID (default "mcherif/Plant-Disease-RAG-Assistant"); an empty
      value disables the download.
    - MODEL_REPO_TYPE (default "space").
    - MODEL_REPO_SUBDIR (default "models/vit-finetuned"); leading/trailing
      slashes are ignored, and an empty value selects the repository root.
    - MODEL_REPO_CACHE_DIR: download target. When unset, the target is
      "model_snapshots" under HF_HOME, or under the system temp directory if
      HF_HOME is unset as well.

    Returns:
        The downloaded sub-directory if it contains model.safetensors, else
        None. Download failures are logged and reported as None, never raised.
    """
    repo_id = os.getenv("MODEL_REPO_ID", "mcherif/Plant-Disease-RAG-Assistant")
    if not repo_id:
        return None
    if snapshot_download is None:
        logger.warning(
            "huggingface_hub is not installed; skipping remote model download."
        )
        return None

    repo_type = os.getenv("MODEL_REPO_TYPE", "space")
    # Normalise so "dir/" or "/dir" still yields a matching glob and a
    # relative join below (an absolute component would discard snapshot_path).
    subdir = os.getenv("MODEL_REPO_SUBDIR", "models/vit-finetuned").strip("/")

    cache_dir = os.getenv("MODEL_REPO_CACHE_DIR")
    if cache_dir:
        cache_path = Path(cache_dir)
    else:
        cache_root = Path(os.getenv("HF_HOME", tempfile.gettempdir()))
        cache_path = cache_root / "model_snapshots"
    cache_path.mkdir(parents=True, exist_ok=True)

    try:
        snapshot_path = snapshot_download(
            repo_id=repo_id,
            repo_type=repo_type,
            allow_patterns=[f"{subdir}/*" if subdir else "*"],
            local_dir=str(cache_path),
            local_dir_use_symlinks=False,
        )
    except Exception:
        # Deliberately best-effort: the caller falls back to raising
        # FileNotFoundError, but the root cause should still be visible.
        logger.warning(
            "Could not download model from %s (repo_type=%s)",
            repo_id,
            repo_type,
            exc_info=True,
        )
        return None

    candidate = Path(snapshot_path) / subdir
    if (candidate / _WEIGHTS_FILENAME).exists():
        return str(candidate)
    return None