|
|
from typing import Any, ClassVar, Dict, Optional, Type |
|
|
|
|
|
from pydantic import BaseModel, Field, PositiveInt, ConfigDict |
|
|
|
|
|
|
|
|
# Size budget, in bytes, for a model's metadata; ModelId.MAX_REPO_ID_LENGTH
# is derived from this value.
MAX_METADATA_BYTES: int = 128

# Length of a full hexadecimal git commit SHA-1 (40 hex characters).
GIT_COMMIT_LENGTH: int = 40

# Length of a SHA-256 digest encoded as base64 (32 bytes -> 44 chars, padded).
SHA256_BASE_64_LENGTH: int = 44

# Number of characters reserved for a competition id.
MAX_COMPETITION_ID_LENGTH: int = 2
|
|
|
|
|
|
|
|
class ModelId(BaseModel):
    """Uniquely identifies a trained model.

    A ``ModelId`` round-trips through a compact ``:``-separated string via
    :meth:`to_compressed_str` and :meth:`from_compressed_str`. The field order is
    ``namespace:name:epoch:commit:hash:competition_id``.
    """

    # Characters left for the repo id inside MAX_METADATA_BYTES after reserving
    # room for the commit, the base64 hash and the competition id; the trailing
    # 4 is presumably for ':' separators.
    # NOTE(review): to_compressed_str now emits 6 fields (5 separators) and also
    # includes `epoch`, neither of which this budget accounts for — confirm the
    # metadata size limit still holds before relying on this value.
    MAX_REPO_ID_LENGTH: ClassVar[int] = (
        MAX_METADATA_BYTES
        - GIT_COMMIT_LENGTH
        - SHA256_BASE_64_LENGTH
        - MAX_COMPETITION_ID_LENGTH
        - 4
    )

    namespace: str = Field(
        description="Namespace where the model can be found. ex. Hugging Face username/org."
    )
    name: str = Field(description="Name of the model.")

    # Optional (but still a required argument) so that the "None" token that
    # from_compressed_str decodes to None validates, matching commit/hash.
    epoch: Optional[str] = Field(
        description="The epoch number to submit as your checkpoint to evaluate e.g. 10"
    )

    commit: Optional[str] = Field(
        description="Commit of the model. May be empty if not yet committed."
    )

    hash: Optional[str] = Field(description="Hash of the trained model.")

    competition_id: Optional[str] = Field(description="The competition id")

    def to_compressed_str(self) -> str:
        """Returns a compressed string representation.

        Fields are joined with ``:`` in declaration order; fields whose value is
        ``None`` are rendered as the literal text ``"None"``.
        """
        return f"{self.namespace}:{self.name}:{self.epoch}:{self.commit}:{self.hash}:{self.competition_id}"

    @classmethod
    def from_compressed_str(cls, cs: str) -> "ModelId":
        """Returns an instance of this class from a compressed string representation.

        Args:
            cs: A string in the format produced by :meth:`to_compressed_str`.
                For ``epoch``, ``commit``, ``hash`` and ``competition_id`` the
                token ``"None"`` decodes to ``None``. The trailing
                ``competition_id`` token may be absent, and it then decodes to
                ``None``. Tokens beyond the sixth are ignored.

        Returns:
            A new ``ModelId`` built from the decoded tokens.

        Raises:
            IndexError: If ``cs`` has fewer than five ``:``-separated tokens.
            pydantic.ValidationError: If the decoded values fail validation.
        """
        tokens = cs.split(":")

        def _decode(token: str) -> Optional[str]:
            # Inverse of the f-string in to_compressed_str, which renders None as "None".
            return None if token == "None" else token

        return cls(
            namespace=tokens[0],
            name=tokens[1],
            epoch=_decode(tokens[2]),
            commit=_decode(tokens[3]),
            hash=_decode(tokens[4]),
            # Older strings may predate competition_id and stop after the hash.
            competition_id=_decode(tokens[5]) if len(tokens) >= 6 else None,
        )
|
|
|
|
|
|
|
|
class Model(BaseModel):
    """A pre-trained foundation model paired with the local directory holding its files."""

    id: ModelId = Field(..., description="Identifier for this model.")
    local_repo_dir: str = Field(
        ..., description="Local repository with the required files."
    )

    # Lets pydantic accept field types it has no built-in validator for.
    model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
|
|
|
|
|
|
class ModelMetadata(BaseModel):
    """Records which trained model was claimed on chain, and at which block."""

    id: ModelId = Field(..., description="Identifier for this trained model.")
    block: PositiveInt = Field(
        ..., description="Block on which this model was claimed on the chain."
    )
|
|
|