# Plant-Disease-Classifier / src/model_paths.py
# Author: mcherif
# Commit message: Download shared model snapshot when missing
# Commit: d7b1b94
import logging
import os
import tempfile
from pathlib import Path
from typing import Optional

from huggingface_hub import snapshot_download
def resolve_model_dir(require_weights: bool = True) -> str:
    """
    Determine which vit-finetuned directory to use for inference/training.

    Priority order:
      1. Explicit MODEL_DIR environment variable ("~" is expanded). The
         override is trusted as-is: it is NOT checked for model weights.
      2. Shared copy in the sibling plant-disease-rag-assistant repo, used
         only when it already contains model.safetensors.
      3. Local models/vit-finetuned directory in this repository.
      4. Only when require_weights is True and the local directory has no
         weights: a snapshot downloaded from the Hugging Face Hub (configured
         by the MODEL_REPO_* environment variables read in
         _download_remote_model).

    Args:
        require_weights: If True, the local directory is only returned when it
            already contains model.safetensors; otherwise the remote snapshot
            is tried. If False, the local directory is created when missing
            (so training code can write into it) and no download is attempted.

    Returns:
        Absolute path of the selected model directory.

    Raises:
        FileNotFoundError: require_weights is True and none of the shared,
            local or downloaded directories contains model.safetensors.
    """
    override = os.getenv("MODEL_DIR")
    if override:
        # expanduser: values from .env files / container env are not
        # tilde-expanded by a shell.
        return os.path.abspath(os.path.expanduser(override))

    # This module lives in <repo>/src, so the repository root is one level up.
    repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))

    shared_dir = _resolve_shared_repo_path(repo_root)
    if shared_dir:
        return shared_dir

    local_dir = os.path.join(repo_root, "models", "vit-finetuned")
    if not require_weights:
        # Training mode: make sure there is somewhere to write artifacts.
        os.makedirs(local_dir, exist_ok=True)
        return local_dir

    if os.path.exists(os.path.join(local_dir, "model.safetensors")):
        return local_dir

    remote_dir = _download_remote_model()
    if remote_dir:
        # The snapshot location may come from a relative HF_HOME /
        # MODEL_REPO_CACHE_DIR; normalise like the other branches.
        return os.path.abspath(remote_dir)

    raise FileNotFoundError(
        "model.safetensors not found in shared, local, or downloaded model "
        "directories. Set MODEL_DIR to point at a valid model location, or "
        "MODEL_REPO_ID to a Hugging Face repo that provides one."
    )
def _resolve_shared_repo_path(repo_root: str) -> Optional[str]:
    """
    Locate the model directory of the sibling plant-disease-rag-assistant repo.

    Args:
        repo_root: Root directory of this repository; the sibling repository
            is looked up next to it (under ``repo_root/..``).

    Returns:
        The absolute path ``<parent>/plant-disease-rag-assistant/models/vit-finetuned``
        if it contains model.safetensors, otherwise None.
    """
    relative_parts = ("..", "plant-disease-rag-assistant", "models", "vit-finetuned")
    candidate_dir = os.path.abspath(os.path.join(repo_root, *relative_parts))
    weights_file = os.path.join(candidate_dir, "model.safetensors")
    return candidate_dir if os.path.exists(weights_file) else None
def _download_remote_model() -> Optional[str]:
    """
    Best-effort download of the fine-tuned ViT weights from the Hugging Face Hub.

    Configuration comes from environment variables:
      * MODEL_REPO_ID (default "mcherif/Plant-Disease-RAG-Assistant"); an
        empty value disables the download.
      * MODEL_REPO_TYPE (default "space").
      * MODEL_REPO_SUBDIR (default "models/vit-finetuned"): folder inside the
        repo that holds the model files. Surrounding slashes are ignored; an
        empty value means the repo root.
      * MODEL_REPO_CACHE_DIR: download destination. When unset, files go to
        "$HF_HOME/model_snapshots", or "<system temp dir>/model_snapshots"
        if HF_HOME is unset as well.

    Any failure (creating the destination, network, auth, missing repo) is
    logged as a warning and reported as None so the caller can raise its own
    error.

    Returns:
        Path of the downloaded model directory if it contains
        model.safetensors, otherwise None.
    """
    repo_id = os.getenv("MODEL_REPO_ID", "mcherif/Plant-Disease-RAG-Assistant")
    if not repo_id:
        return None
    repo_type = os.getenv("MODEL_REPO_TYPE", "space")
    # Strip slashes so "models/vit-finetuned/" does not become the
    # never-matching pattern "models/vit-finetuned//*".
    subdir = os.getenv("MODEL_REPO_SUBDIR", "models/vit-finetuned").strip("/")
    allow_pattern = f"{subdir}/*" if subdir else "*"

    cache_dir = os.getenv("MODEL_REPO_CACHE_DIR")
    if cache_dir:
        cache_path = Path(cache_dir)
    else:
        cache_path = Path(os.getenv("HF_HOME", tempfile.gettempdir())) / "model_snapshots"

    logger = logging.getLogger(__name__)
    try:
        # Inside the try: an unwritable destination is just another reason
        # the best-effort download failed, not a crash.
        cache_path.mkdir(parents=True, exist_ok=True)
        # NOTE(review): local_dir_use_symlinks is deprecated in newer
        # huggingface_hub releases; drop it once the minimum supported
        # version always writes real files into local_dir.
        snapshot_path = snapshot_download(
            repo_id=repo_id,
            repo_type=repo_type,
            allow_patterns=[allow_pattern],
            local_dir=str(cache_path),
            local_dir_use_symlinks=False,
        )
    except Exception as exc:  # best-effort: caller raises if nothing is found
        logger.warning(
            "Could not download model snapshot from %s repo %r into %s: %s",
            repo_type,
            repo_id,
            cache_path,
            exc,
        )
        return None

    # Path / "" is a no-op, so an empty subdir points at the snapshot root.
    candidate = Path(snapshot_path) / subdir
    model_file = candidate / "model.safetensors"
    return str(candidate) if model_file.exists() else None