Spaces:
Sleeping
Sleeping
File size: 1,640 Bytes
2ad4d00 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 | from dataclasses import dataclass, field
from typing import Dict, List, Tuple
@dataclass
class ModelConfig:
    """Architecture-level settings and basic properties of the model."""

    # Identifier of the pretrained base model.
    model_id: str = field(default="CompVis/stable-diffusion-v1-4")
    # Frame size; presumably (height, width) -- TODO confirm against the data pipeline.
    image_size: Tuple[int, int] = field(default=(240, 320))
    # Presumably the number of diffusion timesteps -- confirm against the scheduler setup.
    num_timesteps: int = field(default=100)
    # Presumably how many past steps the model is conditioned on -- confirm with caller.
    history_len: int = field(default=4)
    # Width of an action vector; PredictionConfig.action_map uses 7-slot vectors.
    num_actions: int = field(default=7)
    # Toggle for LoRA; TrainingConfig.lora_rank / lora_alpha only matter when True.
    use_lora: bool = field(default=True)
@dataclass
class TrainingConfig:
    """Settings that only affect the training run."""

    # Repository the training dataset is pulled from.
    repo_id: str = field(default="RevanthGundala/tiny_engine")
    learning_rate: float = field(default=1.0e-4)
    # Portion of the dataset to train on. NOTE(review): 0.01 reads like a
    # fraction (1%) despite the "percentage" name -- confirm in the training loop.
    subset_percentage: float = field(default=0.01)
    batch_size: int = field(default=16)
    num_epochs: int = field(default=1)
    # LoRA hyperparameters; ignored unless ModelConfig.use_lora is True.
    lora_rank: int = field(default=4)
    lora_alpha: int = field(default=4)
def _default_action_map() -> Dict[str, List[int]]:
    """Build a fresh key -> one-hot action vector mapping.

    The key at position i of ``keys_in_order`` maps to a vector with a single 1
    in slot i; "noop" maps to an all-zero vector. A brand-new dict and new lists
    are produced on every call, so dataclass instances never share state.
    """
    keys_in_order = (
        "w",           # MOVE_FORWARD
        "s",           # MOVE_BACKWARD
        "d",           # MOVE_RIGHT
        "a",           # MOVE_LEFT
        "ArrowLeft",   # TURN_LEFT
        "ArrowRight",  # TURN_RIGHT
        " ",           # ATTACK
    )
    width = len(keys_in_order)
    mapping = {
        key: [1 if slot == position else 0 for slot in range(width)]
        for position, key in enumerate(keys_in_order)
    }
    mapping["noop"] = [0] * width  # No operation
    return mapping


@dataclass
class PredictionConfig:
    """Settings used by the prediction server (app.py)."""

    # Repository the model weights are loaded from.
    model_repo_id: str = "RevanthGundala/tiny_engine"
    # Repository holding the video that supplies the starting frame.
    dataset_repo_id: str = "RevanthGundala/tiny_engine"
    # Presumably the epoch whose weights are loaded. NOTE(review): TrainingConfig
    # defaults to num_epochs=1 -- confirm epoch-99 weights actually exist.
    prediction_epoch: int = 99
    # Local directory for weights when MLflow is not used.
    output_dir: str = "output"
    # Key name -> 7-slot one-hot action vector ("noop" is all zeros).
    action_map: Dict[str, List[int]] = field(default_factory=_default_action_map)