| | import os |
| | import dataclasses |
| | from typing import Optional, Tuple |
| |
|
@dataclasses.dataclass()
class ModelConfig:
    """Identifiers and sizes used to assemble the audio/text model.

    Every field has a default, so ``ModelConfig()`` is valid; override any
    value with a keyword argument, e.g. ``ModelConfig(stack_factor=4)``.
    """

    # Model repository identifiers — presumably resolved by a
    # ``from_pretrained``-style loader elsewhere; TODO confirm.
    audio_model_id: str = "openai/whisper-medium"
    text_model_id: str = "sarvamai/sarvam-m"
    # NOTE(review): the exact meaning of these three values (projector width,
    # activation name, stacking factor) is defined by the consuming code.
    hidden_size: int = 2048
    projector_act: str = "gelu"
    stack_factor: int = 8

    def to_dict(self) -> dict:
        """Return this configuration as a new plain ``dict`` (via ``dataclasses.asdict``)."""
        snapshot = dataclasses.asdict(self)
        return snapshot
| |
|
def _env_default(name: str, default: Optional[str] = None, **field_kwargs):
    """Build a dataclass field whose default is read from environment variable *name*.

    The lookup runs each time an instance is constructed (through
    ``default_factory``), not once at import time. Changes to ``os.environ``
    made after this module is imported are therefore honoured.

    Args:
        name: Environment variable to read.
        default: Value used when *name* is not set.
        **field_kwargs: Extra keyword arguments forwarded to
            ``dataclasses.field`` (e.g. ``repr=False``).
    """
    return dataclasses.field(
        default_factory=lambda: os.getenv(name, default), **field_kwargs
    )


@dataclasses.dataclass
class TrainConfig:
    """Training hyper-parameters, dataset selection and logging/publishing options.

    All fields have defaults. ``hub_model_id``, ``hub_token``, ``wandb_project``
    and ``wandb_entity`` fall back to the environment variables
    ``HUB_MODEL_ID``, ``HUB_TOKEN``, ``WANDB_PROJECT`` and ``WANDB_ENTITY``.
    The variables are read when an instance is created.

    ``hub_token`` is excluded from ``repr()`` so the credential is not printed
    or logged by accident. Note that ``dataclasses.asdict`` still includes it.
    """

    # Batching, precision and data loading.
    batch_size: int = 8
    accum_steps: int = 2  # presumably gradient-accumulation steps — confirm with the trainer
    use_bf16: bool = True
    gradient_checkpointing: bool = False
    dataloader_num_workers: int = 8
    dataloader_pin_memory: bool = True

    # Optimiser schedule and run length.
    learning_rate: float = 1e-4
    lr_scheduler_type: str = "cosine"
    num_epochs: int = 1
    max_steps: int = 10000  # NOTE(review): precedence vs. num_epochs is decided by the consumer

    output_dir: str = "./checkpoints"

    # Dataset identifier, subset and splits.
    dataset_name: str = "fixie-ai/common_voice_17_0"
    dataset_subset: str = "hi"
    dataset_split: str = "train"
    val_dataset_split: str = "validation"

    # LoRA adapter settings.
    use_lora: bool = True
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05

    # Hub publishing; id/token defaults come from the environment per instance.
    push_to_hub: bool = False
    hub_model_id: Optional[str] = _env_default("HUB_MODEL_ID")
    # repr=False keeps the credential out of printed / logged configs.
    hub_token: Optional[str] = _env_default("HUB_TOKEN", repr=False)
    hub_private_repo: bool = True

    # Weights & Biases settings; project/entity defaults come from the environment.
    wandb_project: str = _env_default("WANDB_PROJECT", "audio-language-model")
    wandb_entity: Optional[str] = _env_default("WANDB_ENTITY")
    wandb_run_name: Optional[str] = None
    wandb_watch: str = "false"  # stored as a string, not a bool
    wandb_log_model: str = "false"  # stored as a string, not a bool

    # Seed and step cadences for logging, evaluation, checkpointing and samples.
    seed: int = 42
    log_steps: int = 10
    eval_steps: int = 250
    save_steps: int = 500
    save_total_limit: int = 1
    sample_pred_every_steps: int = 250
| |
|