# File size: 8,195 Bytes
# bbef64d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
import os
import cv2
import torch
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from pytorchvideo.models.resnet import create_resnet
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

# -------------------------------
# Custom Dataset for AirLetters
# -------------------------------
class AirLettersDataset(Dataset):
    """Video-classification dataset built from a CSV index of clip files.

    The CSV must provide a ``filename`` column (path relative to ``video_dir``)
    and a free-text ``label`` column; header whitespace is stripped on load.
    Each item is ``(frames, label_id)`` where ``frames`` is a float tensor of
    shape ``(3, num_frames, image_size, image_size)`` and ``label_id`` is an
    int in ``[0, 36]`` (see ``_label_to_id``).
    """

    def __init__(self, csv_path, video_dir, num_frames=8, image_size=224):
        self.df = pd.read_csv(csv_path)
        # Strip stray spaces so "filename"/"label" lookups work on sloppy headers.
        self.df.columns = self.df.columns.str.strip()
        self.video_dir = video_dir
        self.num_frames = num_frames
        self.image_size = image_size
        # Per-frame pipeline: HWC uint8 RGB -> CHW float, square resize,
        # then fixed per-channel normalization.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((image_size, image_size)),
            transforms.Normalize(mean=[0.45, 0.45, 0.45], std=[0.225, 0.225, 0.225])
        ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(frames, label_id)`` for row *idx*.

        If the clip cannot be decoded, a random other row is tried instead (up
        to 10 attempts in total), so the returned sample may not be row *idx*.

        Raises:
            RuntimeError: if 10 consecutive attempts hit unreadable videos.
        """
        for _ in range(10):
            row = self.df.iloc[idx]
            video_path = os.path.join(self.video_dir, row['filename'])
            frames = self._load_video(video_path)
            if frames is not None:
                label = self._label_to_id(row['label'])
                return frames, label
            idx = np.random.randint(0, len(self.df))
        raise RuntimeError("Too many unreadable videos in dataset.")

    def _label_to_id(self, label_text):
        """Map a free-text label string to a class id.

        * text containing "letter" -> the single a-z character after it, 0..25
        * text containing "digit"  -> the integer after it plus 26, 26..35
        * anything else            -> 36 (catch-all class)

        Raises:
            ValueError: if a "letter"/"digit" label is not followed by a single
                a-z letter / a 0-9 digit, instead of silently producing an id
                outside [0, 36] that would only fail later inside the loss.
        """
        label_text = label_text.lower()
        if "letter" in label_text:
            tokens = label_text.split("letter")[-1].split()
            char = tokens[0] if tokens else ""
            if len(char) != 1 or not ("a" <= char <= "z"):
                raise ValueError(f"Unrecognized letter label: {label_text!r}")
            return ord(char.upper()) - ord('A')
        elif "digit" in label_text:
            tokens = label_text.split("digit")[-1].split()
            if not tokens:
                raise ValueError(f"Unrecognized digit label: {label_text!r}")
            digit = int(tokens[0])  # non-numeric text raises ValueError here
            if not 0 <= digit <= 9:
                raise ValueError(f"Unrecognized digit label: {label_text!r}")
            return 26 + digit
        else:
            return 36

    def _load_video(self, video_path):
        """Decode ``num_frames`` evenly spaced frames from *video_path*.

        Returns a ``(C, T, H, W)`` tensor, or ``None`` after printing a warning
        when the file cannot be opened or no frame decodes. Frames that fail to
        decode are dropped and the clip is padded at the end with zero frames.
        The capture handle is released on every path.
        """
        cap = None
        try:
            cap = cv2.VideoCapture(video_path)
            total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            # Reject any non-positive count, not only exactly zero.
            if total <= 0 or not cap.isOpened():
                raise ValueError("Unreadable video")

            frames = []
            step = max(1, total // self.num_frames)

            for i in range(self.num_frames):
                cap.set(cv2.CAP_PROP_POS_FRAMES, i * step)
                ret, frame = cap.read()
                if not ret or frame is None:
                    continue
                # OpenCV decodes to BGR; the transform expects RGB.
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(self.transform(frame))

            if len(frames) == 0:
                raise ValueError("No valid frames")

            while len(frames) < self.num_frames:
                frames.append(torch.zeros_like(frames[0]))

            # (T, C, H, W) -> (C, T, H, W)
            return torch.stack(frames).permute(1, 0, 2, 3)

        except Exception as e:
            print(f"[WARNING] Skipping unreadable video: {video_path} ({str(e)})")
            return None
        finally:
            # Previously release() was skipped whenever an exception fired first.
            if cap is not None:
                cap.release()


# -------------------------------
# Train + Evaluate Function
# -------------------------------
CHECKPOINT_PATH = "checkpoint.pth"
SAVE_INTERVAL = 10000  # optimizer steps between mid-epoch checkpoints


def _save_checkpoint(state, path):
    """Write *state* to *path* through a temp file plus ``os.replace``.

    This way an interrupted write leaves the previous checkpoint intact.
    """
    tmp_path = f"{path}.tmp"
    torch.save(state, tmp_path)
    os.replace(tmp_path, path)


def _evaluate(model, loader, device):
    """Switch *model* to eval mode and return ``(correct, total)`` top-1 counts over *loader*."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    return correct, total


def _accuracy(correct, total):
    """Return *correct* / *total* as a percentage, or NaN when nothing was counted."""
    return 100. * correct / total if total else float("nan")


def train(model, train_loader, val_loader, test_loader, device, num_epochs=5, lr=1e-4):
    """Train *model*, validate after each epoch, then test and save the weights.

    Uses Adam and cross-entropy. If ``CHECKPOINT_PATH`` exists, model and
    optimizer state are restored from it and training resumes. A checkpoint
    is written every ``SAVE_INTERVAL`` optimizer steps and at every epoch end.
    Its ``'batch_idx'`` is the number of batches of ``'epoch'`` already
    trained. Older checkpoints stored the last trained index instead, which
    only means one batch is repeated on resume.

    On a mid-epoch resume, the first ``batch_idx`` batches are skipped. They
    are still produced (loaded) by the loader. If the loader shuffles, they
    are a new random draw, so this restores the step count, not the exact
    samples already seen.

    Args:
        model: module whose output is assumed to be class logits of shape
            (N, num_classes); required by ``CrossEntropyLoss`` and ``max(1)``.
        train_loader, val_loader, test_loader: iterables of (inputs, labels).
        device: device the batches are moved to.
        num_epochs: total epochs, counting those finished before a resume.
        lr: Adam learning rate.

    Returns:
        Final test accuracy in percent (NaN if ``test_loader`` is empty).
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # ===== Resume variables =====
    start_epoch = 0
    global_step = 0
    resume_batch_idx = 0

    # ===== Load checkpoint if exists =====
    if os.path.exists(CHECKPOINT_PATH):
        checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)

        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])

        start_epoch = checkpoint['epoch']
        global_step = checkpoint['step']
        resume_batch_idx = checkpoint['batch_idx']

        print(f"🔁 Resuming from Epoch {start_epoch}, Batch {resume_batch_idx}, Step {global_step}")

    for epoch in range(start_epoch, num_epochs):
        model.train()
        running_loss = 0.0
        num_batches = 0  # batches actually trained this epoch (excludes resume skips)
        correct = 0
        total = 0

        loop = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs}")

        for batch_idx, (inputs, labels) in loop:

            # Skip already-trained batches only on the resumed epoch.
            if epoch == start_epoch and batch_idx < resume_batch_idx:
                continue

            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            global_step += 1

            # Save a checkpoint every SAVE_INTERVAL steps.
            if global_step % SAVE_INTERVAL == 0:
                _save_checkpoint({
                    'epoch': epoch,
                    'step': global_step,
                    # Count of completed batches, so resume skips exactly these.
                    'batch_idx': batch_idx + 1,
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                }, CHECKPOINT_PATH)

                print(f"\n💾 Checkpoint saved at step {global_step}")

        avg_loss = running_loss / num_batches if num_batches else float("nan")
        train_acc = _accuracy(correct, total)
        print(f"\n✅ Epoch {epoch+1} - Loss: {avg_loss:.4f}, Train Accuracy: {train_acc:.2f}%")

        # Save a checkpoint at the end of the epoch (resume starts the next one).
        _save_checkpoint({
            'epoch': epoch + 1,
            'step': global_step,
            'batch_idx': 0,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict()
        }, CHECKPOINT_PATH)

        # ✅ Run validation after each epoch
        val_correct, val_total = _evaluate(model, val_loader, device)
        val_acc = _accuracy(val_correct, val_total)
        print(f"✅ Validation Accuracy: {val_acc:.2f}%")

    # ✅ Final Test Accuracy. _evaluate sets eval mode explicitly. Previously,
    # testing after a resume with no epochs left ran in train mode, which also
    # updated BatchNorm running statistics.
    test_correct, test_total = _evaluate(model, test_loader, device)
    test_acc = _accuracy(test_correct, test_total)
    print(f"🎯 Final Test Accuracy: {test_acc:.2f}%")

    # ✅ Save final model
    # NOTE(review): the filename says "resnext200"; it does not reflect whatever
    # architecture the caller actually built. Kept unchanged for compatibility.
    torch.save(model.state_dict(), "resnext200_airletters.pth")
    print("\n✅ Model saved to resnext200_airletters.pth")
    print("📦 Please upload this file to Hugging Face to preserve it.")
    return test_acc
    print("πŸ“¦ Please upload this file to Hugging Face to preserve it.")

# -------------------------------
# Entry Point
# -------------------------------
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("🚀 Using device:", device)

    # Defaults are the original hard-coded paths; each may be overridden via
    # an environment variable instead of editing this file.
    train_csv = os.environ.get("AIRLETTERS_TRAIN_CSV", "train.csv")
    val_csv = os.environ.get("AIRLETTERS_VAL_CSV", "val.csv")
    test_csv = os.environ.get("AIRLETTERS_TEST_CSV", "test.csv")
    video_dir = os.environ.get(
        "AIRLETTERS_VIDEO_DIR", "/home/mluser/dataset/dataset/videos/videos"
    )

    train_set = AirLettersDataset(train_csv, video_dir)
    val_set = AirLettersDataset(val_csv, video_dir)
    test_set = AirLettersDataset(test_csv, video_dir)

    train_loader = DataLoader(train_set, batch_size=2, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_set, batch_size=2, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_set, batch_size=2, shuffle=False, num_workers=2)

    # 37 outputs match AirLettersDataset._label_to_id: ids 0-25 letters,
    # 26-35 digits, 36 for everything else.
    model = create_resnet(
        input_channel=3,
        model_num_class=37,
        model_depth=101,
        norm=nn.BatchNorm3d,
        activation=nn.ReLU,
    ).to(device)

    train(model, train_loader, val_loader, test_loader, device)