# File: trash/model/model_tracker.py
# Uploaded by Shinichie using huggingface_hub (commit 6e17fd0, verified).
import copy
import datetime
import threading
from typing import Dict, List, Optional, Set
import pickle
import bittensor as bt
import hashlib
from model.data import ModelMetadata
class NoopLock:
    """Context manager that performs no locking at all.

    Used in place of a real lock when the tracker is created with
    ``thread_safe=False`` so that ``with self.lock:`` keeps working.
    """

    def __enter__(self):
        """Enter the context without acquiring anything; yields ``None``."""
        return None

    def __exit__(self, exc_type, exc_value, traceback):
        """Leave the context; returns ``None`` so exceptions propagate."""
        return None
class ModelTracker:
    """Tracks the current model for each miner.

    All public methods take ``self.lock`` around accesses to the internal
    mappings. With ``thread_safe=True`` (the default) that lock is a
    ``threading.RLock``; with ``thread_safe=False`` it is a ``NoopLock`` and
    no synchronization is performed.
    """

    def __init__(
        self,
        thread_safe: bool = True,
    ):
        """Create an empty tracker.

        Args:
            thread_safe (bool): When True, guard state with a re-entrant lock;
                when False, use a no-op lock.
        """
        # Miner hotkey -> model metadata.
        self.miner_hotkey_to_model_metadata_dict: dict[str, ModelMetadata] = dict()

        # Miner hotkey -> last time it was evaluated/loaded/updated.
        self.miner_hotkey_to_last_touched_dict: dict[str, datetime.datetime] = dict()

        # Miner hotkey -> model hash (as recorded by update_model_hash).
        self.miner_hotkey_to_model_hash_dict: dict[str, str] = dict()

        # Overwritten (hotkey, metadata) pairs that may be safe to delete if
        # not currently in use.
        self.old_model_metadata: list[tuple[str, ModelMetadata]] = []

        # Set of (hotkey, metadata.id.hash) pairs that are currently in use.
        self.model_metadata_in_use: set[tuple[str, str]] = set()

        # Make this class thread safe because it will be accessed by multiple threads.
        # One for the downloading new models loop and one for the validating models loop.
        self.lock = threading.RLock() if thread_safe else NoopLock()

    def save_state(self, filepath):
        """Pickle the hotkey -> model metadata mapping to ``filepath``.

        Only ``miner_hotkey_to_model_metadata_dict`` is persisted; the other
        mappings are not written.
        """
        # Hold the lock while dumping so the dict cannot change mid-write.
        with self.lock:
            with open(filepath, "wb") as f:
                pickle.dump(self.miner_hotkey_to_model_metadata_dict, f)

    def load_state(self, filepath):
        """Replace the hotkey -> model metadata mapping with the pickle at ``filepath``."""
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # state files that this process (or a trusted one) has written.
        with open(filepath, "rb") as f:
            loaded = pickle.load(f)

        # Swap the mapping in under the lock, consistent with every other
        # accessor of this dict.
        with self.lock:
            self.miner_hotkey_to_model_metadata_dict = loaded

    def get_miner_hotkey_to_model_metadata_dict(self) -> Dict[str, ModelMetadata]:
        """Returns a deep copy of the mapping from miner hotkey to model metadata."""
        # Return a copy to ensure outside code can't modify the tracked state.
        with self.lock:
            return copy.deepcopy(self.miner_hotkey_to_model_metadata_dict)

    def get_model_metadata_for_miner_hotkey(
        self, hotkey: str
    ) -> Optional[ModelMetadata]:
        """Returns the stored model metadata for a given hotkey, or None if untracked.

        The stored object itself is returned, not a copy.
        """
        with self.lock:
            return self.miner_hotkey_to_model_metadata_dict.get(hotkey)

    def take_model_metadata_for_miner_hotkey(self, hotkey: str) -> Optional[ModelMetadata]:
        """Returns the model metadata for a given hotkey if any. Also, marks it as in use to prevent race conditions."""
        with self.lock:
            metadata = self.miner_hotkey_to_model_metadata_dict.get(hotkey)
            if metadata is None:
                return None
            self.model_metadata_in_use.add((hotkey, metadata.id.hash))
            return metadata

    def release_all(self):
        """Clears every in-use marker."""
        with self.lock:
            self.model_metadata_in_use.clear()

    def release_model_metadata_for_miner_hotkey(self, hotkey: str, metadata: ModelMetadata):
        """Removes the in-use marker for ``(hotkey, metadata.id.hash)``.

        Releasing a pair that is not marked as in use is logged as an error
        and otherwise ignored.
        """
        with self.lock:
            pair = (hotkey, metadata.id.hash)
            if pair not in self.model_metadata_in_use:
                bt.logging.error("Model metadata is not in use!")
            if (hotkey, metadata) in self.old_model_metadata:
                bt.logging.trace(f"Releasing old model metadata for hotkey: {hotkey}")
            # discard() rather than remove(): the missing-pair case was already
            # reported above and must not additionally raise KeyError.
            self.model_metadata_in_use.discard(pair)

    def get_miner_hotkey_to_last_touched_dict(self) -> Dict[str, datetime.datetime]:
        """Returns a deep copy of the mapping from miner hotkey to last time it was touched."""
        # Return a copy to ensure outside code can't modify the tracked state.
        with self.lock:
            return copy.deepcopy(self.miner_hotkey_to_last_touched_dict)

    def on_hotkeys_updated(self, incoming_hotkeys: Set[str]):
        """Notifies the tracker which hotkeys are currently being tracked on the metagraph.

        Entries for hotkeys that are not in ``incoming_hotkeys`` are dropped
        from the metadata, last-touched and model-hash mappings.
        """
        with self.lock:
            per_hotkey_state = (
                (self.miner_hotkey_to_model_metadata_dict, "metadata"),
                (self.miner_hotkey_to_last_touched_dict, "timestamp"),
                # Also prune hashes so a departed hotkey leaves no stale entry.
                (self.miner_hotkey_to_model_hash_dict, "hash"),
            )
            for mapping, label in per_hotkey_state:
                # Materialize the stale keys first; the mapping is mutated below.
                for hotkey in set(mapping) - incoming_hotkeys:
                    del mapping[hotkey]
                    bt.logging.trace(f"Removed outdated hotkey {label}: {hotkey} from ModelTracker")

    def get_and_clear_old_models(self) -> list[tuple[str, ModelMetadata]]:
        """Returns the overwritten (hotkey, metadata) pairs that are no longer in use.

        Returned pairs are removed from the internal list; pairs whose
        ``(hotkey, metadata.id.hash)`` is still marked in use are kept.
        """
        with self.lock:
            to_delete = []
            still_in_use = []
            for hotkey, model in self.old_model_metadata:
                if (hotkey, model.id.hash) in self.model_metadata_in_use:
                    still_in_use.append((hotkey, model))
                else:
                    to_delete.append((hotkey, model))
            self.old_model_metadata = still_in_use
            return to_delete

    def on_miner_model_updated(
        self,
        hotkey: str,
        model_metadata: ModelMetadata,
    ) -> None:
        """Notifies the tracker that a miner has had their associated model updated.

        Any metadata previously stored for the hotkey is appended to the list
        of old models, and the hotkey's last-touched time is set to now.

        Args:
            hotkey (str): The miner's hotkey.
            model_metadata (ModelMetadata): The latest model metadata of the miner.
        """
        with self.lock:
            if hotkey in self.miner_hotkey_to_model_metadata_dict:
                old_metadata = self.miner_hotkey_to_model_metadata_dict[hotkey]
                self.old_model_metadata.append((hotkey, old_metadata))

            self.miner_hotkey_to_model_metadata_dict[hotkey] = model_metadata
            self.miner_hotkey_to_last_touched_dict[hotkey] = datetime.datetime.now()

            bt.logging.trace(f"Updated Miner {hotkey}. ModelMetadata={model_metadata}.")

    def touch_miner_model(self, hotkey: str) -> None:
        """Notifies the tracker that a miner has been touched."""
        now = datetime.datetime.now()
        with self.lock:
            self.miner_hotkey_to_last_touched_dict[hotkey] = now

            bt.logging.trace(f"Touched Miner {hotkey}. datetime={now}.")

    def touch_all_miner_models(self) -> None:
        """Sets the last-touched time of every hotkey that has metadata to now."""
        now = datetime.datetime.now()
        with self.lock:
            # Iterating the metadata dict while writing to the last-touched
            # dict is safe: they are different mappings.
            for hotkey in self.miner_hotkey_to_model_metadata_dict:
                self.miner_hotkey_to_last_touched_dict[hotkey] = now

            bt.logging.trace(f"Touched All Miners. datetime={now}.")

    def update_model_hash(self, hotkey: str, new_model_hash: str) -> bool:
        """Set (insert or overwrite) the model hash recorded for a hotkey.

        Args:
            hotkey (str): The miner's hotkey.
            new_model_hash (str): The new model hash to be set.

        Returns:
            bool: Always True; the hash is stored whether or not the hotkey
            was previously known.
        """
        with self.lock:
            self.miner_hotkey_to_model_hash_dict[hotkey] = new_model_hash
            return True

    def calculate_file_hash(self, file_path: str) -> str:
        """Calculate the SHA1 hex digest of a file, reading it in 64 KiB chunks."""
        sha1 = hashlib.sha1()
        with open(file_path, "rb") as f:
            while chunk := f.read(65536):  # Read in 64kb chunks
                sha1.update(chunk)
        return sha1.hexdigest()

    def is_model_unique(
        self, hotkey_to_check: str, block_to_check: int, model_checkpoint_path: str
    ) -> tuple[bool, str]:
        """Check whether another tracked miner already holds the same model file hash.

        The checkpoint file is hashed, then compared against the recorded hash
        of every other hotkey that has metadata. A match counts as a duplicate
        only if that other hotkey's ``metadata.block`` is strictly lower than
        ``block_to_check``. In both outcomes the computed hash is recorded for
        ``hotkey_to_check``.

        Args:
            hotkey_to_check (str): Hotkey of the miner whose model is checked.
            block_to_check (int): Block associated with that miner's model.
            model_checkpoint_path (str): Path of the file to hash.

        Returns:
            tuple[bool, str]: ``(is_unique, file_hash)``.
        """
        # Hash outside the lock: this is file I/O and touches no shared state.
        hash_to_check = self.calculate_file_hash(model_checkpoint_path)

        with self.lock:
            for hotkey, metadata in self.miner_hotkey_to_model_metadata_dict.items():
                if hotkey == hotkey_to_check or hotkey not in self.miner_hotkey_to_model_hash_dict:
                    continue

                if self.miner_hotkey_to_model_hash_dict[hotkey] == hash_to_check and metadata.block < block_to_check:
                    bt.logging.warning(
                        f"*** Model with hash {hash_to_check} on block {block_to_check} is not unique. Already in use by {hotkey} on block {metadata.block} for model {metadata.id.namespace}/{metadata.id.name}. ***"
                    )

                    # Update the model hash for the hotkey
                    self.update_model_hash(hotkey_to_check, hash_to_check)
                    return False, hash_to_check

            # Update the model hash for the hotkey
            self.update_model_hash(hotkey_to_check, hash_to_check)
            return True, hash_to_check