File size: 5,808 Bytes
6e17fd0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import bittensor as bt
import datetime
import os
from typing import Dict
from constants import CompetitionParameters
from model.data import Model, ModelId
from model.storage.disk import utils
from model.storage.local_model_store import LocalModelStore
from transformers import AutoModelForCausalLM, AutoTokenizer
from pathlib import Path


class DiskModelStore(LocalModelStore):
    """Local storage based implementation for storing and retrieving a model on disk.

    Models are laid out as ``<miner dir>/<model_id.name>``, where the miner dir
    for a hotkey is ``utils.get_local_miner_dir(base_dir, hotkey)`` and all miner
    dirs live under ``utils.get_local_miners_dir(base_dir)``.
    """

    def __init__(self, base_dir: str):
        """Create a store rooted at ``base_dir``, creating the miners dir if needed.

        Args:
            base_dir: Root directory under which per-miner model directories live.
        """
        self.base_dir = base_dir
        os.makedirs(utils.get_local_miners_dir(base_dir), exist_ok=True)

    def get_path(self, hotkey: str) -> str:
        """Returns the path to where this store would locate this hotkey."""
        return utils.get_local_miner_dir(self.base_dir, hotkey)

    def _get_model_dir(self, hotkey: str, model_name: str) -> str:
        """Return the directory holding the model named ``model_name`` for ``hotkey``."""
        # NOTE(review): model_name is joined verbatim; if it can come from an
        # untrusted source, an absolute path or ".." segments would escape the
        # miner dir. Confirm names are validated upstream.
        return os.path.join(self.get_path(hotkey), model_name)

    def store_model(
        self,
        hotkey: str,
        model: Model,
        hf_model: AutoModelForCausalLM,
        hf_tokenizer: AutoTokenizer,
    ) -> ModelId:
        """Stores a trained model locally.

        Calls ``save_pretrained`` on both the model and the tokenizer with the
        directory derived from ``hotkey`` and ``model.id.name``, then records that
        directory on ``model.local_repo_dir``.

        Args:
            hotkey: Hotkey of the miner the model belongs to.
            model: Model whose ``id.name`` selects the target directory. Mutated:
                ``model.local_repo_dir`` is set to the directory written to.
            hf_model: Model object to persist.
            hf_tokenizer: Tokenizer to persist next to the model.

        Returns:
            ``model.id``.
        """
        model_dir = self._get_model_dir(hotkey, model.id.name)
        hf_model.save_pretrained(model_dir)
        hf_tokenizer.save_pretrained(model_dir)
        model.local_repo_dir = model_dir

        return model.id

    def retrieve_model(
        self, hotkey: str, model_id: ModelId, model_parameters: CompetitionParameters
    ) -> Model:
        """Retrieves a trained model locally.

        No files are read or checked here: the returned ``Model`` simply points at
        the directory ``store_model`` would have written ``model_id`` to.

        Args:
            hotkey: Hotkey of the miner the model belongs to.
            model_id: Identifier whose ``name`` selects the directory.
            model_parameters: Unused by this implementation; accepted to keep the
                ``LocalModelStore`` interface.

        Returns:
            A ``Model`` with ``id=model_id`` and ``local_repo_dir`` set.
        """
        model_dir = self._get_model_dir(hotkey, model_id.name)
        return Model(id=model_id, local_repo_dir=model_dir)

    def delete_unreferenced_models(
        self,
        valid_models_by_hotkey: Dict[str, ModelId],
        model_touched_by_hotkey: Dict[str, datetime.datetime],
        grace_period_seconds: int,
    ) -> None:
        """Check across all of local storage and delete unreferenced models out of grace period.

        Walks the layout written by ``store_model`` (``<miner dir>/<model name>``):

        * A hotkey directory whose hotkey is not in ``valid_models_by_hotkey`` is
          passed to ``utils.remove_dir_out_of_grace``.
        * Under a valid hotkey, a model directory whose name differs from that
          hotkey's current ``model_id.name`` is passed to
          ``utils.remove_dir_out_of_grace``.
        * The current model directory of a valid hotkey is passed to
          ``utils.remove_dir_out_of_grace_by_datetime`` when the hotkey has an
          entry in ``model_touched_by_hotkey``.

        The grace-period decision itself is delegated to those ``utils`` helpers.
        Regular files (non-directories) are ignored.

        Args:
            valid_models_by_hotkey: Current model per hotkey that must be kept.
            model_touched_by_hotkey: Last-touched time per hotkey, used to expire
                models that are still referenced but stale.
            grace_period_seconds: Grace period forwarded to the ``utils`` helpers.
        """
        miners_dir = Path(utils.get_local_miners_dir(self.base_dir))
        if not miners_dir.is_dir():
            # Nothing has been stored (or the directory was removed externally).
            return

        # Materialize listings before deleting anything inside them.
        hotkeys = [entry.name for entry in miners_dir.iterdir() if entry.is_dir()]

        for hotkey in hotkeys:
            hotkey_path = self.get_path(hotkey)
            model_id = valid_models_by_hotkey.get(hotkey)

            if model_id is None:
                if utils.remove_dir_out_of_grace(hotkey_path, grace_period_seconds):
                    bt.logging.trace(
                        f"Removed directory for unreferenced hotkey: {hotkey}."
                    )
                continue

            model_names = [
                entry.name for entry in Path(hotkey_path).iterdir() if entry.is_dir()
            ]
            for model_name in model_names:
                model_path = self._get_model_dir(hotkey, model_name)

                # Compare names, not path strings, so an un-normalized base_dir
                # (e.g. "./x" vs "x") cannot make a valid model look unreferenced.
                if model_name != model_id.name:
                    if utils.remove_dir_out_of_grace(model_path, grace_period_seconds):
                        bt.logging.trace(
                            f"Removing directory for unreferenced model at: {model_path}."
                        )
                    continue

                last_touched = model_touched_by_hotkey.get(hotkey)
                if last_touched is not None and utils.remove_dir_out_of_grace_by_datetime(
                    model_path, grace_period_seconds, last_touched
                ):
                    bt.logging.trace(
                        f"Removing directory for stale model at: {model_path}."
                    )