File size: 1,889 Bytes
3663dd0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import torch
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from model import Model
from dataloader import ShorthandGenerationDataset, data_split
from config import CONFIG
from tqdm import tqdm  # Import tqdm for progress bar

def collate_fn(batch):
    # print(batch) # debugging
    imgs, labels, additional = zip(*batch)
    imgs = pad_sequence(imgs, batch_first=True, padding_value=0)
    labels = pad_sequence(labels, batch_first=True, padding_value=0)
    additional = torch.stack(additional)
    return imgs, labels, additional

# Split the data
train_files, val_files, test_files, max_H, max_W, max_seq_length = data_split()

# Initialize dataset and dataloaders
train_dataset = ShorthandGenerationDataset(train_files, max_H, max_W, aug_types=9, max_label_leng=max_seq_length, channels=1)
val_dataset = ShorthandGenerationDataset(val_files, max_H, max_W, aug_types=1, max_label_leng=max_seq_length, channels=1)
test_dataset = ShorthandGenerationDataset(test_files, max_H, max_W, aug_types=1, max_label_leng=max_seq_length, channels=1)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

# Initialize model
config = CONFIG()
model = Model(max_H, max_W, config)

# Example training loop
num_epochs = 10  # Define the number of epochs
for epoch in range(num_epochs):
    model.train()
    for imgs, labels, additional in tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{num_epochs}"):
        # Training step
        pass

    model.eval()
    with torch.no_grad():
        for imgs, labels, additional in tqdm(val_loader, desc=f"Validation Epoch {epoch+1}/{num_epochs}"):
            # Validation step
            pass