| from models.text2video_model import Text2VideoModel |
| from training.trainer import Text2VideoTrainer |
| from config.model_config import CONFIG |
| import torch |
|
|
def main():
    """Build the text-to-video model, its optimizer, and a trainer.

    Chooses CUDA when available, otherwise falls back to CPU; all
    hyperparameters come from the project-level CONFIG dict.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Gather the architecture hyperparameters in one place before
    # constructing the model.
    model_kwargs = dict(
        vocab_size=CONFIG['vocab_size'],
        embed_dim=CONFIG['embed_dim'],
        latent_dim=CONFIG['latent_dim'],
        num_frames=CONFIG['num_frames'],
        frame_size=CONFIG['frame_size'],
    )
    model = Text2VideoModel(**model_kwargs).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG['learning_rate'])
    trainer = Text2VideoTrainer(model, optimizer, device)
| |
| |
|
|
# NOTE(review): this guard appears ABOVE the Text2VideoTrainer class defined
# later in the file, so main() runs before that local class exists. main()
# currently resolves Text2VideoTrainer to the one imported from
# training.trainer, which the local class then shadows — confirm which
# implementation is intended and consider moving this guard to the end of
# the file.
if __name__ == '__main__':
    main()
| |
class Text2VideoTrainer:
    """Runs single optimization steps for a text-to-video model.

    NOTE(review): this class shadows the ``Text2VideoTrainer`` imported from
    ``training.trainer`` at the top of the file — confirm which definition
    callers are meant to use.
    """

    def __init__(self, model, optimizer, device):
        """
        Args:
            model: callable mapping a text batch to a generated video tensor
                (shape must match the target batch passed to ``train_step``).
            optimizer: torch optimizer constructed over ``model``'s parameters.
            device: torch device the caller has placed the model on; stored
                for reference, not used directly in this class.
        """
        self.model = model
        self.optimizer = optimizer
        self.device = device

    def train_step(self, text_batch, video_batch):
        """Run one forward/backward/update pass and return the scalar loss.

        Bug fix: the original body called ``F.mse_loss`` but ``F``
        (``torch.nn.functional``) was never imported in this file, so the
        first training step raised ``NameError``. The fully qualified name
        is used instead, relying only on the existing ``torch`` import.

        Args:
            text_batch: model input (e.g. tokenized text tensor).
            video_batch: target video tensor, same shape as the model output.

        Returns:
            float: the MSE loss value for this step.
        """
        self.optimizer.zero_grad()

        generated_video = self.model(text_batch)
        # Pixel-wise reconstruction loss between generated and target video.
        loss = torch.nn.functional.mse_loss(generated_video, video_batch)

        loss.backward()
        self.optimizer.step()

        return loss.item()
|
|