# Scrape residue from a file viewer: file size 2,921 bytes, revision 6e17fd0.
from typing import Any, ClassVar, Dict, Optional, Type
# from transformers import PreTrainedModel, PreTrainedTokenizerBase
from pydantic import BaseModel, Field, PositiveInt, ConfigDict
# Upper bound, in bytes, on the metadata string that can be stored on the chain.
MAX_METADATA_BYTES = 128

# A git commit hash is a 20-byte SHA-1 digest written as hex: 2 chars per byte.
GIT_COMMIT_LENGTH = 2 * 20

# A 32-byte sha256 digest in padded base64: 4 chars per 3-byte group, rounded up.
SHA256_BASE_64_LENGTH = 4 * ((32 + 2) // 3)

# Longest allowed competition id, in characters.
MAX_COMPETITION_ID_LENGTH = 2
class ModelId(BaseModel):
    """Uniquely identifies a trained model.

    An instance round-trips through a colon-delimited "compressed" string,
    ``namespace:name:epoch:commit:hash:competition_id``, in which unset
    optional fields are written as the literal text ``"None"``.
    """

    # Characters left for the repo id after subtracting the fixed-size fields
    # from the on-chain metadata budget.
    # NOTE(review): "4 separators" predates the ``epoch`` field; the compressed
    # string now has 6 fields (5 separators) plus the epoch text, so this bound
    # looks too generous. Value kept unchanged for compatibility -- confirm
    # against the chain limit before relying on it.
    MAX_REPO_ID_LENGTH: ClassVar[int] = (
        MAX_METADATA_BYTES
        - GIT_COMMIT_LENGTH
        - SHA256_BASE_64_LENGTH
        - MAX_COMPETITION_ID_LENGTH
        - 4  # separators
    )

    namespace: str = Field(
        description="Namespace where the model can be found. ex. Hugging Face username/org."
    )
    name: str = Field(description="Name of the model.")
    # Still required, but nullable: from_compressed_str decodes "None" to None.
    epoch: Optional[str] = Field(
        description="The epoch number to submit as your checkpoint to evaluate e.g. 10"
    )

    # When handling a model locally the commit and hash are not necessary.
    # Commit must be filled when trying to download from a remote store.
    # An explicit default is needed: in pydantic v2, Optional alone does not
    # make a field optional.
    commit: Optional[str] = Field(
        default=None,
        description="Commit of the model. May be empty if not yet committed.",
    )
    # Hash is filled automatically when uploading to or downloading from a remote store.
    hash: Optional[str] = Field(default=None, description="Hash of the trained model.")
    # Identifier for competition
    competition_id: Optional[str] = Field(default=None, description="The competition id")

    def to_compressed_str(self) -> str:
        """Return the colon-delimited compressed representation.

        Fields that are ``None`` are rendered as the text ``"None"``.
        """
        return f"{self.namespace}:{self.name}:{self.epoch}:{self.commit}:{self.hash}:{self.competition_id}"

    @classmethod
    def from_compressed_str(cls, cs: str) -> "ModelId":
        """Build an instance from the output of ``to_compressed_str``.

        The token ``"None"`` is decoded as ``None`` for epoch, commit, hash and
        competition_id. A missing sixth token also yields
        ``competition_id=None``.

        NOTE(review): a namespace or name containing ":" cannot round-trip,
        because the string is split on every colon.

        Raises:
            IndexError: if ``cs`` has fewer than five colon-separated tokens.
        """
        tokens = cs.split(":")

        def _decode(token: str) -> Optional[str]:
            # to_compressed_str writes unset fields as the literal "None".
            return None if token == "None" else token

        return cls(
            namespace=tokens[0],
            name=tokens[1],
            epoch=_decode(tokens[2]),
            commit=_decode(tokens[3]),
            hash=_decode(tokens[4]),
            competition_id=_decode(tokens[5]) if len(tokens) >= 6 else None,
        )
class Model(BaseModel):
    """Represents a pre trained foundation model."""

    # Pairs a model's identifier with the local directory holding its files.
    # NOTE(review): no field currently needs arbitrary types (the transformers
    # import is commented out); kept so the model config is unchanged.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Both fields are required (explicit `...` default).
    id: ModelId = Field(..., description="Identifier for this model.")
    local_repo_dir: str = Field(
        ...,
        description="Local repository with the required files.",
    )
class ModelMetadata(BaseModel):
    # Pairs a model identifier with the chain block on which it was claimed.
    # Both fields are required (explicit `...` default); `block` must be > 0.
    id: ModelId = Field(..., description="Identifier for this trained model.")
    block: PositiveInt = Field(
        ...,
        description="Block on which this model was claimed on the chain.",
    )
|