File size: 9,195 Bytes
6e17fd0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 |
import copy
import datetime
import threading
from typing import Dict, List, Optional, Set
import pickle
import bittensor as bt
import hashlib
from model.data import ModelMetadata
class NoopLock:
    """Context manager with the shape of a lock that does nothing.

    ``ModelTracker`` uses it instead of ``threading.RLock`` when constructed
    with ``thread_safe=False``, so its ``with self.lock:`` blocks still work.
    """

    def __enter__(self):
        # Nothing to acquire; ``with NoopLock() as x`` binds x to None.
        return None

    def __exit__(self, exc_type, exc_value, traceback):
        # Nothing to release. Returning None (falsy) lets any exception raised
        # inside the ``with`` body propagate unchanged.
        return None
class ModelTracker:
    """Tracks the current model for each miner.

    Per miner hotkey, the tracker records:

    - the latest model metadata;
    - when the miner was last touched;
    - the hash of its model checkpoint file.

    It also keeps superseded metadata awaiting cleanup and the set of
    (hotkey, model id hash) pairs currently taken by a consumer.

    Thread safe: shared state is only accessed while holding ``self.lock``.
    That is a re-entrant lock, or a no-op lock when constructed with
    ``thread_safe=False``.
    """

    def __init__(
        self,
        thread_safe: bool = True,
    ):
        # Map from miner hotkey to its latest model metadata.
        self.miner_hotkey_to_model_metadata_dict: dict[str, ModelMetadata] = dict()
        # Map from miner hotkey to the last time it was evaluated/loaded/updated.
        self.miner_hotkey_to_last_touched_dict: dict[str, datetime.datetime] = dict()
        # Map from miner hotkey to its checkpoint hash (set via update_model_hash,
        # e.g. by is_model_unique).
        self.miner_hotkey_to_model_hash_dict: dict[str, str] = dict()
        # Overwritten (hotkey, metadata) entries that may be safe to delete once
        # they are no longer in use.
        self.old_model_metadata: list[tuple[str, ModelMetadata]] = []
        # (hotkey, metadata.id.hash) pairs currently taken by a consumer.
        self.model_metadata_in_use: set[tuple[str, str]] = set()
        # Accessed by both the model-downloading loop and the model-validating
        # loop. Re-entrant so a locked method can call another locked method
        # (is_model_unique -> update_model_hash).
        self.lock = threading.RLock() if thread_safe else NoopLock()

    def save_state(self, filepath):
        """Pickle the miner hotkey -> model metadata mapping to ``filepath``.

        Only ``miner_hotkey_to_model_metadata_dict`` is persisted. Touch times,
        checkpoint hashes and in-use/old-model bookkeeping are not saved.
        """
        with self.lock:
            # Hold the lock for the whole dump so the dict cannot change
            # mid-serialization.
            with open(filepath, "wb") as f:
                pickle.dump(self.miner_hotkey_to_model_metadata_dict, f)

    def load_state(self, filepath):
        """Replace the hotkey -> metadata mapping with the one pickled at ``filepath``.

        Counterpart of ``save_state``. Only load files this process (or a
        trusted one) wrote: unpickling can execute arbitrary code.
        """
        # Do the file I/O outside the lock, then swap the dict in under the lock
        # so it is consistent with every other accessor.
        with open(filepath, "rb") as f:
            loaded = pickle.load(f)
        with self.lock:
            self.miner_hotkey_to_model_metadata_dict = loaded

    def get_miner_hotkey_to_model_metadata_dict(self) -> Dict[str, "ModelMetadata"]:
        """Returns a deep copy of the mapping from miner hotkey to model metadata."""
        # Deep copy so outside code can't modify the tracked metadata.
        with self.lock:
            return copy.deepcopy(self.miner_hotkey_to_model_metadata_dict)

    def get_model_metadata_for_miner_hotkey(
        self, hotkey: str
    ) -> Optional["ModelMetadata"]:
        """Returns the model metadata for ``hotkey``, or None if it is not tracked.

        The returned object is the tracked instance itself, not a copy.
        """
        with self.lock:
            return self.miner_hotkey_to_model_metadata_dict.get(hotkey)

    def take_model_metadata_for_miner_hotkey(self, hotkey: str) -> Optional["ModelMetadata"]:
        """Returns the model metadata for ``hotkey`` (or None) and marks it as in use.

        While the (hotkey, metadata.id.hash) pair is in use,
        ``get_and_clear_old_models`` will not hand it out for deletion. Pair
        each successful take with
        ``release_model_metadata_for_miner_hotkey``.
        """
        with self.lock:
            metadata = self.miner_hotkey_to_model_metadata_dict.get(hotkey)
            if metadata is None:
                return None
            self.model_metadata_in_use.add((hotkey, metadata.id.hash))
            return metadata

    def release_all(self):
        """Mark every taken model metadata as no longer in use."""
        with self.lock:
            self.model_metadata_in_use.clear()

    def release_model_metadata_for_miner_hotkey(self, hotkey: str, metadata: "ModelMetadata"):
        """Mark the (hotkey, metadata) pair previously taken as no longer in use.

        Releasing a pair that is not in use is logged as an error and otherwise
        ignored.
        """
        with self.lock:
            pair = (hotkey, metadata.id.hash)
            if pair not in self.model_metadata_in_use:
                # A stray/double release is a caller bug. Previously this fell
                # through to set.remove() and raised KeyError in the caller's loop.
                bt.logging.error("Model metadata is not in use!")
                return
            if (hotkey, metadata) in self.old_model_metadata:
                bt.logging.trace(f"Releasing old model metadata for hotkey: {hotkey}")
            self.model_metadata_in_use.remove(pair)

    def get_miner_hotkey_to_last_touched_dict(self) -> Dict[str, datetime.datetime]:
        """Returns a deep copy of the mapping from miner hotkey to last-touched time."""
        # Deep copy so outside code can't modify the tracked timestamps.
        with self.lock:
            return copy.deepcopy(self.miner_hotkey_to_last_touched_dict)

    def on_hotkeys_updated(self, incoming_hotkeys: Set[str]):
        """Notifies the tracker which hotkeys are currently being tracked on the metagraph.

        Drops metadata, last-touched time and checkpoint hash for every hotkey
        not in ``incoming_hotkeys``. Superseded entries in
        ``old_model_metadata`` are kept so they can still be cleaned up.
        """
        with self.lock:
            for hotkey in set(self.miner_hotkey_to_model_metadata_dict) - incoming_hotkeys:
                del self.miner_hotkey_to_model_metadata_dict[hotkey]
                bt.logging.trace(f"Removed outdated hotkey metadata: {hotkey} from ModelTracker")

            for hotkey in set(self.miner_hotkey_to_last_touched_dict) - incoming_hotkeys:
                del self.miner_hotkey_to_last_touched_dict[hotkey]
                bt.logging.trace(f"Removed outdated hotkey timestamp: {hotkey} from ModelTracker")

            # Also drop stale checkpoint hashes. Otherwise they accumulate, and
            # they could be compared against a re-registered hotkey's new
            # metadata in is_model_unique.
            for hotkey in set(self.miner_hotkey_to_model_hash_dict) - incoming_hotkeys:
                del self.miner_hotkey_to_model_hash_dict[hotkey]
                bt.logging.trace(f"Removed outdated hotkey model hash: {hotkey} from ModelTracker")

    def get_and_clear_old_models(self) -> list[tuple[str, "ModelMetadata"]]:
        """Remove and return superseded (hotkey, metadata) entries that are not in use.

        Entries whose (hotkey, metadata.id.hash) pair is still in use stay
        queued for a later call.
        """
        with self.lock:
            to_delete = []
            still_in_use = []
            for hotkey, model in self.old_model_metadata:
                if (hotkey, model.id.hash) in self.model_metadata_in_use:
                    still_in_use.append((hotkey, model))
                else:
                    to_delete.append((hotkey, model))
            self.old_model_metadata = still_in_use
            return to_delete

    def on_miner_model_updated(
        self,
        hotkey: str,
        model_metadata: "ModelMetadata",
    ) -> None:
        """Notifies the tracker that a miner has had their associated model updated.

        Any previously tracked metadata for the hotkey is queued in
        ``old_model_metadata``, and the hotkey's last-touched time is set to now.

        Args:
            hotkey (str): The miner's hotkey.
            model_metadata (ModelMetadata): The latest model metadata of the miner.
        """
        with self.lock:
            if hotkey in self.miner_hotkey_to_model_metadata_dict:
                # NOTE(review): the previous metadata is queued even when it is
                # identical to model_metadata. Confirm that whoever deletes the
                # models returned by get_and_clear_old_models handles that case.
                old_metadata = self.miner_hotkey_to_model_metadata_dict[hotkey]
                self.old_model_metadata.append((hotkey, old_metadata))
            self.miner_hotkey_to_model_metadata_dict[hotkey] = model_metadata
            self.miner_hotkey_to_last_touched_dict[hotkey] = datetime.datetime.now()
            bt.logging.trace(f"Updated Miner {hotkey}. ModelMetadata={model_metadata}.")

    def touch_miner_model(self, hotkey: str) -> None:
        """Records that ``hotkey`` was touched now (local, naive datetime)."""
        now = datetime.datetime.now()
        with self.lock:
            self.miner_hotkey_to_last_touched_dict[hotkey] = now
            bt.logging.trace(f"Touched Miner {hotkey}. datetime={now}.")

    def touch_all_miner_models(self) -> None:
        """Records every hotkey with tracked metadata as touched at the same instant."""
        now = datetime.datetime.now()
        with self.lock:
            for hotkey in self.miner_hotkey_to_model_metadata_dict:
                self.miner_hotkey_to_last_touched_dict[hotkey] = now
            bt.logging.trace(f"Touched All Miners. datetime={now}.")

    def update_model_hash(self, hotkey: str, new_model_hash: str) -> bool:
        """
        Set the checkpoint hash for a given hotkey.

        Args:
            hotkey (str): The miner's hotkey.
            new_model_hash (str): The new model hash to be set.

        Returns:
            bool: Always True. The hash is stored whether or not the hotkey was
            already tracked.
        """
        with self.lock:
            self.miner_hotkey_to_model_hash_dict[hotkey] = new_model_hash
            return True

    def calculate_file_hash(self, file_path: str) -> str:
        """Return the hex SHA1 digest of the file at ``file_path``.

        The file is read in 64 KiB chunks so large checkpoints are never fully
        loaded into memory.
        """
        sha1 = hashlib.sha1()
        with open(file_path, "rb") as f:
            while chunk := f.read(65536):
                sha1.update(chunk)
        return sha1.hexdigest()

    def is_model_unique(
        self, hotkey_to_check: str, block_to_check: int, model_checkpoint_path: str
    ) -> tuple[bool, str]:
        """Check whether another miner already uploaded an identical checkpoint earlier.

        A model is considered a copy when another tracked hotkey has the same
        checkpoint hash and its metadata block is strictly lower than
        ``block_to_check``. In both outcomes the computed hash is recorded for
        ``hotkey_to_check``.

        Returns:
            tuple[bool, str]: (is_unique, SHA1 hex digest of the checkpoint).
        """
        # Hash outside the lock: reading a large checkpoint can be slow.
        hash_to_check = self.calculate_file_hash(model_checkpoint_path)
        with self.lock:
            is_unique = True
            for hotkey, metadata in self.miner_hotkey_to_model_metadata_dict.items():
                if hotkey == hotkey_to_check or hotkey not in self.miner_hotkey_to_model_hash_dict:
                    continue
                if self.miner_hotkey_to_model_hash_dict[hotkey] == hash_to_check and metadata.block < block_to_check:
                    bt.logging.warning(
                        f"*** Model with hash {hash_to_check} on block {block_to_check} is not unique. Already in use by {hotkey} on block {metadata.block} for model {metadata.id.namespace}/{metadata.id.name}. ***"
                    )
                    is_unique = False
                    break
            # Record the hash for the checked hotkey whatever the outcome.
            self.update_model_hash(hotkey_to_check, hash_to_check)
            return is_unique, hash_to_check
|