"""Factory entry point for assembling a ``ScorePredictor`` from named encoders."""

from .noise_encoders import get_noise_encoder
from .text_encoders import get_text_encoder
from .model import ScorePredictor

# Names accepted by ``get_model`` for each encoder kind.
NOISE_ENCODERS = ['residualconv']
TEXT_ENCODERS = ['attnpool', 'lightattnpool', 'pertokenscalar']


def get_model(
    noise_enc: str = 'residualconv',
    text_enc: str = 'attnpool',
    dropout: float = 0.1,
    num_heads: int = 1,
    spatial_size: int = 128,
    in_channels: int = 4,
    embed_dim: int = 2048,
    seq_len: int = 77,
    pos_encoding: str = 'none',
) -> ScorePredictor:
    """Build a ``ScorePredictor`` from a text encoder and a noise encoder.

    Args:
        noise_enc: Noise encoder name; must be listed in ``NOISE_ENCODERS``.
            NOTE(review): the name is only validated here and is not
            forwarded to ``get_noise_encoder`` -- confirm that is intended.
        text_enc: Text encoder name; must be listed in ``TEXT_ENCODERS``.
            Forwarded to ``get_text_encoder`` as its first argument.
        dropout: Forwarded to ``ScorePredictor``.
        num_heads: Forwarded to ``ScorePredictor``.
        spatial_size: Forwarded to ``get_noise_encoder``.
        in_channels: Forwarded to ``get_noise_encoder``.
        embed_dim: Forwarded to ``get_text_encoder``.
        seq_len: Forwarded to ``get_text_encoder``.
        pos_encoding: Forwarded to ``get_text_encoder``.

    Returns:
        The ``ScorePredictor`` wrapping the two freshly built encoders.

    Raises:
        ValueError: If ``noise_enc`` or ``text_enc`` is not a known name.
            The noise encoder name is checked first.
    """
    # Guard clauses: reject unknown names before constructing anything.
    noise_is_known = noise_enc in NOISE_ENCODERS
    if not noise_is_known:
        raise ValueError(f"Unknown noise encoder: {noise_enc}. Available: {NOISE_ENCODERS}")

    text_is_known = text_enc in TEXT_ENCODERS
    if not text_is_known:
        raise ValueError(f"Unknown text encoder: {text_enc}. Available: {TEXT_ENCODERS}")

    # The text encoder is constructed before the noise encoder; keep this
    # order in case construction has order-dependent effects.
    text_kwargs = {
        'embed_dim': embed_dim,
        'seq_len': seq_len,
        'pos_encoding': pos_encoding,
    }
    text_branch = get_text_encoder(text_enc, **text_kwargs)

    noise_kwargs = {
        'spatial_size': spatial_size,
        'in_channels': in_channels,
    }
    noise_branch = get_noise_encoder(**noise_kwargs)

    predictor = ScorePredictor(
        noise_encoder=noise_branch,
        text_encoder=text_branch,
        dropout=dropout,
        num_heads=num_heads,
    )
    return predictor


__all__ = [
    'get_model',
    'ScorePredictor',
    'get_text_encoder',
    'get_noise_encoder',
    'NOISE_ENCODERS',
    'TEXT_ENCODERS',
]