File size: 3,006 Bytes
9625337
d7b1b94
 
 
 
 
9625337
 
efd28a7
9625337
efd28a7
 
 
 
 
 
 
 
 
 
9625337
 
 
 
 
 
 
d7b1b94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9625337
 
 
 
efd28a7
9625337
 
 
 
 
 
 
d7b1b94
9625337
d7b1b94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
efd28a7
d7b1b94
 
9625337
d7b1b94
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import logging
import os
import tempfile
from pathlib import Path
from typing import Optional

from huggingface_hub import snapshot_download


def resolve_model_dir(require_weights: bool = True) -> str:
    """
    Pick the vit-finetuned directory used for inference or training.

    Lookup order:
      1. The MODEL_DIR environment variable, if set and non-empty. It is
         returned as an absolute path without inspecting its contents.
      2. The shared copy in the sibling plant-disease-rag-assistant repo,
         used only when its model.safetensors exists.
      3. This repository's models/vit-finetuned directory.

    With require_weights=True, step 3 is accepted only if it already holds
    model.safetensors. Otherwise a snapshot is fetched from the Hugging Face
    Hub via _download_remote_model(). If that also yields nothing,
    FileNotFoundError is raised.

    With require_weights=False, the local directory is created when missing
    and returned, so training code can write into it.
    """
    env_dir = os.getenv("MODEL_DIR")
    if env_dir:
        return os.path.abspath(env_dir)

    project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))

    sibling = _resolve_shared_repo_path(project_root)
    if sibling:
        return sibling

    own_dir = os.path.join(project_root, "models", "vit-finetuned")

    if not require_weights:
        # Training path: guarantee there is somewhere to write checkpoints.
        os.makedirs(own_dir, exist_ok=True)
        return own_dir

    if os.path.exists(os.path.join(own_dir, "model.safetensors")):
        return own_dir

    # Last resort: try to pull the weights from the remote repository.
    downloaded = _download_remote_model()
    if not downloaded:
        raise FileNotFoundError(
            "model.safetensors not found in shared or local model directories. "
            "Set MODEL_DIR to point at a valid model location."
        )
    return downloaded


def _resolve_shared_repo_path(repo_root: str) -> Optional[str]:
    """
    Locate the vit-finetuned model in the sibling plant-disease-rag-assistant repo.

    Args:
        repo_root: Root directory of this repository; the sibling repo is
            expected next to it (i.e. under its parent directory).

    Returns:
        The absolute path of <parent>/plant-disease-rag-assistant/models/vit-finetuned
        when that directory contains model.safetensors, otherwise None.
    """
    candidate = os.path.abspath(
        os.path.join(
            repo_root,
            os.pardir,
            "plant-disease-rag-assistant",
            "models",
            "vit-finetuned",
        )
    )
    weights = os.path.join(candidate, "model.safetensors")
    return candidate if os.path.exists(weights) else None


def _download_remote_model() -> Optional[str]:
    """
    Best-effort download of the vit-finetuned weights from the Hugging Face Hub.

    Configuration (environment variables):
      MODEL_REPO_ID         Hub repo to pull from
                            (default "mcherif/Plant-Disease-RAG-Assistant").
                            An empty value disables the download.
      MODEL_REPO_TYPE       Hub repo type (default "space").
      MODEL_REPO_SUBDIR     Model directory inside the repo
                            (default "models/vit-finetuned"). Leading and
                            trailing slashes are ignored; empty means the
                            repo root.
      MODEL_REPO_CACHE_DIR  Local directory to download into. When unset,
                            $HF_HOME/model_snapshots is used, or
                            <system temp dir>/model_snapshots if HF_HOME is
                            unset or empty.

    Returns:
        Path to the downloaded model directory if it contains
        model.safetensors, otherwise None. Failures (cache directory cannot
        be created, download errors) are logged as warnings and reported as
        None, so the caller can fall back to its own error message.
    """
    repo_id = os.getenv("MODEL_REPO_ID", "mcherif/Plant-Disease-RAG-Assistant")
    repo_type = os.getenv("MODEL_REPO_TYPE", "space")
    # Strip slashes so the allow-pattern below never contains "//".
    subdir = os.getenv("MODEL_REPO_SUBDIR", "models/vit-finetuned").strip("/")

    if not repo_id:
        return None

    cache_dir = os.getenv("MODEL_REPO_CACHE_DIR")
    if cache_dir:
        cache_path = Path(cache_dir)
    else:
        # `or` (not a getenv default) so an empty HF_HOME does not resolve
        # to a cwd-relative directory.
        base_dir = os.getenv("HF_HOME") or tempfile.gettempdir()
        cache_path = Path(base_dir) / "model_snapshots"

    logger = logging.getLogger(__name__)

    try:
        cache_path.mkdir(parents=True, exist_ok=True)
    except OSError as exc:
        logger.warning("Cannot create model cache directory %s: %s", cache_path, exc)
        return None

    pattern = f"{subdir}/*" if subdir else "*"
    try:
        snapshot_path = snapshot_download(
            repo_id=repo_id,
            repo_type=repo_type,
            allow_patterns=[pattern],
            local_dir=str(cache_path),
            # Kept for older huggingface_hub releases; newer ones deprecate it.
            local_dir_use_symlinks=False,
        )
    except Exception as exc:  # best-effort: any download failure -> fallback
        logger.warning(
            "Failed to download model from %s (type=%s, subdir=%r): %s",
            repo_id,
            repo_type,
            subdir,
            exc,
        )
        return None

    candidate = Path(snapshot_path) / subdir
    if not (candidate / "model.safetensors").exists():
        logger.warning("Downloaded snapshot has no model.safetensors in %s", candidate)
        return None
    return str(candidate)