import os

import cv2
import torch
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from pytorchvideo.models.resnet import create_resnet
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

# NOTE(review): several log messages below start with character sequences
# that look like mis-decoded emoji. They are runtime strings and have been
# left untouched; confirm the file's intended encoding before changing them.


# -------------------------------
# Custom Dataset for AirLetters
# -------------------------------
class AirLettersDataset(Dataset):
    """Video dataset driven by a CSV with ``filename`` and ``label`` columns.

    Each item is a ``(frames, label_id)`` pair where ``frames`` is a float
    tensor holding ``num_frames`` sampled frames, resized to
    ``image_size x image_size`` and normalized, laid out channel-first
    (channels, frames, height, width).
    """

    # How many different rows ``__getitem__`` tries before giving up.
    _MAX_LOAD_ATTEMPTS = 10

    def __init__(self, csv_path, video_dir, num_frames=8, image_size=224):
        """Read the annotation CSV and build the per-frame transform.

        Args:
            csv_path: Path of the CSV file listing the videos.
            video_dir: Directory that the CSV's ``filename`` values are
                joined onto.
            num_frames: Number of frames returned per video.
            image_size: Side length frames are resized to.
        """
        self.df = pd.read_csv(csv_path)
        # Header cells may carry stray whitespace; strip so that
        # row['filename'] / row['label'] lookups work.
        self.df.columns = self.df.columns.str.strip()
        self.video_dir = video_dir
        self.num_frames = num_frames
        self.image_size = image_size
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((image_size, image_size)),
            transforms.Normalize(
                mean=[0.45, 0.45, 0.45],
                std=[0.225, 0.225, 0.225],
            ),
        ])

    def __len__(self):
        """Return the number of rows in the annotation CSV."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(frames, label_id)`` for row ``idx``.

        If the video of the requested row cannot be decoded, a randomly
        chosen row is tried instead, up to ``_MAX_LOAD_ATTEMPTS`` attempts
        in total.

        Raises:
            RuntimeError: If every attempt hit an unreadable video.
        """
        for _ in range(self._MAX_LOAD_ATTEMPTS):
            row = self.df.iloc[idx]
            video_path = os.path.join(self.video_dir, row['filename'])
            frames = self._load_video(video_path)
            if frames is not None:
                label = self._label_to_id(row['label'])
                return frames, label
            # Unreadable video: fall back to a random replacement row.
            idx = np.random.randint(0, len(self.df))
        raise RuntimeError("Too many unreadable videos in dataset.")

    def _label_to_id(self, label_text):
        """Map a textual label to an integer class id.

        * text containing ``"letter"``: the first token after the last
          ``"letter"`` is mapped to ``ord(token.upper()) - ord('A')``;
        * text containing ``"digit"``: the first token after the last
          ``"digit"`` is mapped to ``26 + int(token)``;
        * anything else: ``36``.

        The comparison is case-insensitive. A label that names "letter" or
        "digit" without a usable token after it raises (``IndexError``,
        ``TypeError`` or ``ValueError``) instead of being mapped silently.
        """
        label_text = label_text.lower()
        if "letter" in label_text:
            char = label_text.split("letter")[-1].strip().split()[0]
            return ord(char.upper()) - ord('A')
        elif "digit" in label_text:
            digit = label_text.split("digit")[-1].strip().split()[0]
            return 26 + int(digit)
        else:
            return 36

    def _load_video(self, video_path):
        """Decode ``num_frames`` frames from ``video_path``.

        Frames are taken at positions ``0, step, 2*step, ...`` with
        ``step = max(1, total_frames // num_frames)``. Positions that fail
        to decode are skipped and the clip is padded with all-zero frames
        up to ``num_frames``.

        Returns:
            A tensor of stacked frames permuted to
            (channels, frames, height, width), or ``None`` when the video
            cannot be opened or yields no frame at all (a warning is
            printed in that case).
        """
        cap = None
        try:
            cap = cv2.VideoCapture(video_path)
            total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if total == 0 or not cap.isOpened():
                raise ValueError("Unreadable video")

            frames = []
            step = max(1, total // self.num_frames)
            for i in range(self.num_frames):
                cap.set(cv2.CAP_PROP_POS_FRAMES, i * step)
                ret, frame = cap.read()
                if not ret or frame is None:
                    continue
                # OpenCV decodes to BGR; the transform expects RGB.
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(self.transform(frame))

            if len(frames) == 0:
                raise ValueError("No valid frames")

            # Pad short clips so every sample has the same frame count.
            while len(frames) < self.num_frames:
                frames.append(torch.zeros_like(frames[0]))

            # stack -> (frames, channels, H, W); permute -> channels first.
            return torch.stack(frames).permute(1, 0, 2, 3)
        except Exception as e:
            # Deliberate best-effort: any decoding failure turns into
            # "skip this video" so one bad file does not stop training.
            print(f"[WARNING] Skipping unreadable video: {video_path} ({e})")
            return None
        finally:
            # Always free the capture handle, including on the error paths
            # above, which previously returned without releasing it.
            if cap is not None:
                cap.release()


# -------------------------------
# Train + Evaluate Function
# -------------------------------
CHECKPOINT_PATH = "checkpoint.pth"
SAVE_INTERVAL = 10000


def _save_checkpoint(model, optimizer, epoch, step, batch_idx):
    """Write the resume state to CHECKPOINT_PATH without partial files.

    The state is saved to a temporary sibling file first and then moved
    into place with ``os.replace``, so an interrupted save does not leave
    a truncated checkpoint behind.
    """
    tmp_path = CHECKPOINT_PATH + ".tmp"
    torch.save({
        'epoch': epoch,
        'step': step,
        'batch_idx': batch_idx,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }, tmp_path)
    os.replace(tmp_path, CHECKPOINT_PATH)


def _evaluate(model, loader, device):
    """Return top-1 accuracy (in percent) of ``model`` over ``loader``.

    Puts the model in eval mode and disables gradient tracking. Returns
    ``0.0`` when the loader yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    return 100. * correct / total if total else 0.0


def train(model, train_loader, val_loader, test_loader, device,
          num_epochs=5, lr=1e-4):
    """Train ``model``, validate after each epoch, then test and save it.

    Training state (model, optimizer, epoch, global step, next batch index)
    is checkpointed to ``CHECKPOINT_PATH`` every ``SAVE_INTERVAL`` steps and
    at the end of every epoch; if that file exists on entry, training
    resumes from it. After the last epoch the test accuracy is printed and
    the model weights are written to ``resnext200_airletters.pth``.

    Args:
        model: Network to train; called as ``model(inputs)`` and expected
            to return per-class scores.
        train_loader: Loader yielding ``(inputs, labels)`` training batches.
        val_loader: Loader evaluated after each epoch.
        test_loader: Loader evaluated once after training.
        device: Device that batches are moved to.
        num_epochs: Total number of epochs (default 5).
        lr: Adam learning rate (default 1e-4).
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # ===== Resume variables =====
    start_epoch = 0
    global_step = 0
    resume_batch_idx = 0

    # ===== Load checkpoint if exists =====
    if os.path.exists(CHECKPOINT_PATH):
        checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        global_step = checkpoint['step']
        resume_batch_idx = checkpoint['batch_idx']
        print(f"šŸ” Resuming from Epoch {start_epoch}, Batch {resume_batch_idx}, Step {global_step}")

    for epoch in range(start_epoch, num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        batches_trained = 0

        loop = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs}")
        for batch_idx, (inputs, labels) in loop:
            # Skip already-trained batches only on resume epoch
            if epoch == start_epoch and batch_idx < resume_batch_idx:
                continue

            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            batches_trained += 1
            global_step += 1

            # Save checkpoint every SAVE_INTERVAL steps. The stored index is
            # the NEXT batch to train (batch_idx + 1); storing batch_idx
            # itself made a resumed run train that batch position twice.
            if global_step % SAVE_INTERVAL == 0:
                _save_checkpoint(model, optimizer, epoch, global_step, batch_idx + 1)
                print(f"\nšŸ’¾ Checkpoint saved at step {global_step}")

        # Reset after first resumed epoch
        resume_batch_idx = 0

        # Average over the batches actually trained: a resumed epoch handles
        # fewer than len(train_loader), and may handle none at all.
        train_acc = 100. * correct / total if total else 0.0
        epoch_loss = running_loss / batches_trained if batches_trained else 0.0
        print(f"\nāœ… Epoch {epoch+1} - Loss: {epoch_loss:.4f}, Train Accuracy: {train_acc:.2f}%")

        # Save checkpoint at end of epoch (resume starts at the next epoch)
        _save_checkpoint(model, optimizer, epoch + 1, global_step, 0)

        # Run validation after each epoch
        val_acc = _evaluate(model, val_loader, device)
        print(f"āœ… Validation Accuracy: {val_acc:.2f}%")

    # Final test accuracy. _evaluate switches to eval mode itself, which
    # matters when the checkpoint was already at the last epoch and the loop
    # above never ran.
    test_acc = _evaluate(model, test_loader, device)
    print(f"šŸŽÆ Final Test Accuracy: {test_acc:.2f}%")

    # Save final model
    # NOTE(review): the file name says "resnext200" while the entry point
    # below builds a depth-101 ResNet; confirm the intended name.
    torch.save(model.state_dict(), "resnext200_airletters.pth")
    print("\nāœ… Model saved to resnext200_airletters.pth")
    print("šŸ“¦ Please upload this file to Hugging Face to preserve it.")


# -------------------------------
# Entry Point
# -------------------------------
def main():
    """Build the datasets, loaders and model, then run training."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("šŸš€ Using device:", device)

    train_csv = "train.csv"  # Update with your path
    val_csv = "val.csv"  # Update with your path
    test_csv = "test.csv"  # Update with your path
    video_dir = "/home/mluser/dataset/dataset/videos/videos"  # Update with your path

    train_set = AirLettersDataset(train_csv, video_dir)
    val_set = AirLettersDataset(val_csv, video_dir)
    test_set = AirLettersDataset(test_csv, video_dir)

    train_loader = DataLoader(train_set, batch_size=2, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_set, batch_size=2, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_set, batch_size=2, shuffle=False, num_workers=2)

    # 37 classes matches the ids produced by AirLettersDataset._label_to_id
    # (0-25 letters, 26-35 digits, 36 everything else).
    model = create_resnet(
        input_channel=3,
        model_num_class=37,
        model_depth=101,
        norm=nn.BatchNorm3d,
        activation=nn.ReLU,
    ).to(device)

    train(model, train_loader, val_loader, test_loader, device)


if __name__ == "__main__":
    main()