DIMPU1516 committed on
Commit
bbef64d
·
verified ·
1 Parent(s): 272752e

Upload train_2.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_2.py +237 -0
train_2.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import torch
4
+ import numpy as np
5
+ import pandas as pd
6
+ from torch.utils.data import Dataset, DataLoader
7
+ from torchvision import transforms
8
+ from pytorchvideo.models.resnet import create_resnet
9
+ import torch.nn as nn
10
+ import torch.optim as optim
11
+ from tqdm import tqdm
12
+
13
+ # -------------------------------
14
+ # Custom Dataset for AirLetters
15
+ # -------------------------------
16
class AirLettersDataset(Dataset):
    """Video dataset that reads clips listed in a CSV file.

    The CSV must provide a ``filename`` column (joined onto ``video_dir``)
    and a ``label`` column. Labels are free text: text containing
    ``letter <X>`` maps to ids 0-25, text containing ``digit <N>`` maps to
    ids 26-35, and everything else maps to the catch-all id 36.

    Args:
        csv_path: Path to the annotation CSV.
        video_dir: Directory the ``filename`` entries are relative to.
        num_frames: Number of frames sampled from each clip.
        image_size: Side length (pixels) every frame is resized to.
    """

    def __init__(self, csv_path, video_dir, num_frames=8, image_size=224):
        self.df = pd.read_csv(csv_path)
        # Header cells sometimes carry stray whitespace (" label"); normalise
        # them so the column lookups in __getitem__ work.
        self.df.columns = self.df.columns.str.strip()
        self.video_dir = video_dir
        self.num_frames = num_frames
        self.image_size = image_size
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((image_size, image_size)),
            transforms.Normalize(mean=[0.45, 0.45, 0.45], std=[0.225, 0.225, 0.225])
        ])

    def __len__(self):
        """Return the number of rows in the annotation CSV."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(frames, label_id)`` for row ``idx``.

        If the video for ``idx`` cannot be decoded, a random other row is
        tried instead, up to 10 attempts in total.

        Raises:
            RuntimeError: If all 10 attempts hit unreadable videos.
            ValueError: If the row's label text is malformed.
        """
        for _ in range(10):
            row = self.df.iloc[idx]
            video_path = os.path.join(self.video_dir, row['filename'])
            frames = self._load_video(video_path)
            if frames is not None:
                label = self._label_to_id(row['label'])
                return frames, label
            # Unreadable clip: substitute a random sample rather than abort.
            idx = np.random.randint(0, len(self.df))
        raise RuntimeError("Too many unreadable videos in dataset.")

    def _label_to_id(self, label_text):
        """Map a label string to a class id in ``[0, 36]``.

        Returns:
            0-25 for ``letter A``..``letter Z`` (case-insensitive),
            26-35 for ``digit 0``..``digit 9``, and 36 for any label that
            mentions neither "letter" nor "digit".

        Raises:
            ValueError: If the label is not a string, or mentions a
                letter/digit but the following token is not a single
                A-Z character / an integer in 0-9. Failing here gives a
                readable error instead of an out-of-range target reaching
                the loss function.
        """
        if not isinstance(label_text, str):
            raise ValueError(f"Label must be a string, got {label_text!r}")

        lowered = label_text.lower()

        if "letter" in lowered:
            tokens = lowered.split("letter")[-1].split()
            char = tokens[0].upper() if tokens else ""
            if len(char) != 1 or not 'A' <= char <= 'Z':
                raise ValueError(f"Cannot parse letter from label {label_text!r}")
            return ord(char) - ord('A')

        if "digit" in lowered:
            tokens = lowered.split("digit")[-1].split()
            try:
                value = int(tokens[0])
            except (IndexError, ValueError):
                raise ValueError(
                    f"Cannot parse digit from label {label_text!r}"
                ) from None
            if not 0 <= value <= 9:
                raise ValueError(f"Digit out of range in label {label_text!r}")
            return 26 + value

        return 36

    def _load_video(self, video_path):
        """Decode ``num_frames`` evenly strided frames from ``video_path``.

        Frames that fail to decode are skipped and the clip is padded with
        all-zero frames up to ``num_frames``.

        Returns:
            The transformed frames stacked along a new leading axis and then
            permuted with ``(1, 0, 2, 3)``, or ``None`` when the video cannot
            be opened or yields no frames at all (a warning is printed).
        """
        cap = None
        try:
            cap = cv2.VideoCapture(video_path)
            total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if total == 0 or not cap.isOpened():
                raise ValueError("Unreadable video")

            frames = []
            step = max(1, total // self.num_frames)

            for i in range(self.num_frames):
                cap.set(cv2.CAP_PROP_POS_FRAMES, i * step)
                ret, frame = cap.read()
                if not ret or frame is None:
                    continue
                # OpenCV decodes to BGR; the transform pipeline expects RGB.
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(self.transform(frame))

            if not frames:
                raise ValueError("No valid frames")

            # Short clips: pad so every sample has the same temporal length.
            while len(frames) < self.num_frames:
                frames.append(torch.zeros_like(frames[0]))

            return torch.stack(frames).permute(1, 0, 2, 3)

        except Exception as e:
            # Best-effort loading: one corrupt file must not kill training.
            print(f"[WARNING] Skipping unreadable video: {video_path} ({str(e)})")
            return None

        finally:
            # Release the capture on every path so failures don't leak handles.
            if cap is not None:
                cap.release()
86
+
87
+
88
+ # -------------------------------
89
+ # Train + Evaluate Function
90
+ # -------------------------------
91
CHECKPOINT_PATH = "checkpoint.pth"
SAVE_INTERVAL = 10000


def _save_checkpoint(model, optimizer, epoch, step, batch_idx):
    """Atomically write the resumable training state to CHECKPOINT_PATH.

    ``batch_idx`` is the index of the NEXT batch to train within ``epoch``.
    """
    tmp_path = CHECKPOINT_PATH + ".tmp"
    torch.save({
        'epoch': epoch,
        'step': step,
        'batch_idx': batch_idx,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict()
    }, tmp_path)
    # Write-then-rename so a crash mid-save never corrupts the only checkpoint.
    os.replace(tmp_path, CHECKPOINT_PATH)


def _evaluate(model, loader, device):
    """Return top-1 accuracy in percent of ``model`` on ``loader`` (0.0 if empty)."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    return 100. * correct / total if total else 0.0


def train(model, train_loader, val_loader, test_loader, device, num_epochs=5):
    """Train ``model``, validate after every epoch, then test and save it.

    Training resumes from ``CHECKPOINT_PATH`` when that file exists. A
    checkpoint is written every ``SAVE_INTERVAL`` optimizer steps and at the
    end of every epoch. The final weights are saved to
    ``resnext200_airletters.pth``.

    Args:
        model: Network to optimise; must already live on ``device``.
        train_loader: Sized iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of validation batches.
        test_loader: Iterable of test batches, evaluated once at the end.
        device: Device the batches are moved to.
        num_epochs: Total number of epochs to train for (default 5).
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    # ===== Resume variables =====
    start_epoch = 0
    global_step = 0
    resume_batch_idx = 0

    # ===== Load checkpoint if exists =====
    if os.path.exists(CHECKPOINT_PATH):
        checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)

        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])

        start_epoch = checkpoint['epoch']
        global_step = checkpoint['step']
        resume_batch_idx = checkpoint['batch_idx']

        print(f"🔁 Resuming from Epoch {start_epoch}, Batch {resume_batch_idx}, Step {global_step}")

    for epoch in range(start_epoch, num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        batches_trained = 0  # batches actually optimised this epoch

        loop = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs}")

        for batch_idx, (inputs, labels) in loop:

            # Skip already-trained batches only on resume epoch.
            # NOTE: with a shuffling loader these are not the same samples as
            # before the interruption; only the batch count is preserved.
            if epoch == start_epoch and batch_idx < resume_batch_idx:
                continue

            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            batches_trained += 1

            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            global_step += 1

            # Save checkpoint every SAVE_INTERVAL steps
            if global_step % SAVE_INTERVAL == 0:
                # Store the NEXT batch index: this batch is already trained,
                # so resuming at batch_idx would repeat it.
                _save_checkpoint(model, optimizer, epoch, global_step, batch_idx + 1)
                print(f"\n💾 Checkpoint saved at step {global_step}")

        # Reset after first resumed epoch
        resume_batch_idx = 0

        # Average over the batches that really ran (a resumed epoch runs
        # fewer than len(train_loader)) and guard against an empty epoch.
        avg_loss = running_loss / max(1, batches_trained)
        train_acc = 100. * correct / total if total else 0.0
        print(f"\n✅ Epoch {epoch+1} - Loss: {avg_loss:.4f}, Train Accuracy: {train_acc:.2f}%")

        # Save checkpoint at end of epoch
        _save_checkpoint(model, optimizer, epoch + 1, global_step, 0)

        # ✅ Run validation after each epoch
        val_acc = _evaluate(model, val_loader, device)
        print(f"✅ Validation Accuracy: {val_acc:.2f}%")

    # ✅ Final Test Accuracy
    # _evaluate switches to eval mode itself, so this is correct even when the
    # epoch loop above did not run (resuming an already finished training).
    test_acc = _evaluate(model, test_loader, device)
    print(f"🎯 Final Test Accuracy: {test_acc:.2f}%")

    # ✅ Save final model
    torch.save(model.state_dict(), "resnext200_airletters.pth")
    print("\n✅ Model saved to resnext200_airletters.pth")
    print("📦 Please upload this file to Hugging Face to preserve it.")
208
+
209
+ # -------------------------------
210
+ # Entry Point
211
+ # -------------------------------
212
if __name__ == "__main__":
    # Prefer the GPU when one is visible, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("🚀 Using device:", device)

    # Dataset locations -- update these with your own paths.
    video_dir = "/home/mluser/dataset/dataset/videos/videos"
    split_csvs = ("train.csv", "val.csv", "test.csv")

    # One dataset per split, built in train / val / test order.
    train_set, val_set, test_set = [
        AirLettersDataset(csv_file, video_dir) for csv_file in split_csvs
    ]

    # Only the training split is shuffled; all loaders share the other settings.
    loader_kwargs = {"batch_size": 2, "num_workers": 2}
    train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_set, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_set, shuffle=False, **loader_kwargs)

    # 3D ResNet of depth 101 with 37 output classes.
    model = create_resnet(
        input_channel=3,
        model_num_class=37,
        model_depth=101,
        norm=nn.BatchNorm3d,
        activation=nn.ReLU,
    )
    model = model.to(device)

    train(model, train_loader, val_loader, test_loader, device)
237
+