rukia07 commited on
Commit
aa8831b
·
verified ·
1 Parent(s): 3ab23d1

Add GPU training script

Browse files
Files changed (1) hide show
  1. train_gpu.py +283 -0
train_gpu.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ RT-DETR Flowchart Detection - GPU Training Script
3
+ ===================================================
4
+ Fine-tunes RT-DETR R18 for single-class flowchart bounding box detection.
5
+
6
+ Model: PekingU/rtdetr_r18vd_coco_o365 → rukia07/rtdetr-flowchart-detector
7
+ Dataset: rukia07/flowchart-detection-dataset (COCO format, 2500 train / 500 val)
8
+
9
+ Requirements:
10
+ pip install transformers torch torchvision albumentations pycocotools
11
+ pip install accelerate huggingface_hub
12
+
13
+ Usage:
14
+ # Full training (recommended: GPU with >= 8GB VRAM)
15
+ python train_gpu.py
16
+
17
+ # Quick test
18
+ python train_gpu.py --epochs 1 --max_train 100 --max_val 20
19
+
20
+ Architecture: RT-DETR (Real-Time DEtection TRansformer)
21
+ - ResNet-18 backbone → HybridEncoder → TransformerDecoder
22
+ - NMS-free, end-to-end detection
23
+ - 20M params, 217 FPS on T4 GPU
24
+ - Single class: "flowchart" (class 0)
25
+ """
26
+
27
+ import argparse
28
+ import json
29
+ import os
30
+ import torch
31
+ import numpy as np
32
+ from pathlib import Path
33
+ from PIL import Image
34
+ from torch.utils.data import Dataset
35
+ from transformers import (
36
+ AutoModelForObjectDetection,
37
+ AutoImageProcessor,
38
+ TrainingArguments,
39
+ Trainer,
40
+ )
41
+ from huggingface_hub import hf_hub_download, snapshot_download
42
+ import albumentations as A
43
+
44
+ # ---------------------------------------------------------------------------
45
+ # Dataset
46
+ # ---------------------------------------------------------------------------
47
+
48
+ class COCODetectionDataset(Dataset):
49
+ """COCO-format detection dataset for flowchart detection."""
50
+
51
+ def __init__(self, image_dir, annotation_file, processor, augment=False, max_samples=None):
52
+ self.image_dir = Path(image_dir)
53
+ self.processor = processor
54
+ self.augment = augment
55
+
56
+ with open(annotation_file) as f:
57
+ coco = json.load(f)
58
+
59
+ self.images = {img["id"]: img for img in coco["images"]}
60
+
61
+ # Build image_id -> annotations mapping
62
+ self.img_annots = {}
63
+ for ann in coco.get("annotations", []):
64
+ img_id = ann["image_id"]
65
+ if img_id not in self.img_annots:
66
+ self.img_annots[img_id] = []
67
+ self.img_annots[img_id].append(ann)
68
+
69
+ self.image_ids = list(self.images.keys())
70
+ if max_samples:
71
+ self.image_ids = self.image_ids[:max_samples]
72
+
73
+ # Augmentation pipeline
74
+ if augment:
75
+ self.transform = A.Compose([
76
+ A.HorizontalFlip(p=0.5),
77
+ A.RandomBrightnessContrast(p=0.3),
78
+ A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05, p=0.3),
79
+ A.GaussNoise(p=0.2),
80
+ ], bbox_params=A.BboxParams(
81
+ format="coco", label_fields=["category_ids"], min_visibility=0.3
82
+ ))
83
+ else:
84
+ self.transform = None
85
+
86
+ def __len__(self):
87
+ return len(self.image_ids)
88
+
89
+ def __getitem__(self, idx):
90
+ img_id = self.image_ids[idx]
91
+ img_info = self.images[img_id]
92
+
93
+ # Load image
94
+ img_path = self.image_dir / img_info["file_name"]
95
+ image = Image.open(img_path).convert("RGB")
96
+ w, h = image.size
97
+
98
+ # Get annotations
99
+ annots = self.img_annots.get(img_id, [])
100
+
101
+ if annots:
102
+ bboxes = [a["bbox"] for a in annots] # [x, y, w, h] COCO format
103
+ categories = [a["category_id"] for a in annots]
104
+ else:
105
+ bboxes = []
106
+ categories = []
107
+
108
+ # Apply augmentation
109
+ if self.transform and bboxes:
110
+ img_np = np.array(image)
111
+ transformed = self.transform(
112
+ image=img_np, bboxes=bboxes, category_ids=categories
113
+ )
114
+ image = Image.fromarray(transformed["image"])
115
+ bboxes = transformed["bboxes"]
116
+ categories = transformed["category_ids"]
117
+
118
+ # Convert COCO [x, y, w, h] to DETR format [cx, cy, w, h] normalized
119
+ targets = {"image_id": img_id, "annotations": []}
120
+ for bbox, cat in zip(bboxes, categories):
121
+ x, y, bw, bh = bbox
122
+ targets["annotations"].append({
123
+ "bbox": [x, y, bw, bh],
124
+ "category_id": cat,
125
+ "area": bw * bh,
126
+ "iscrowd": 0,
127
+ })
128
+
129
+ # Process with RT-DETR processor
130
+ encoding = self.processor(
131
+ images=image,
132
+ annotations=targets,
133
+ return_tensors="pt",
134
+ )
135
+
136
+ pixel_values = encoding["pixel_values"].squeeze(0)
137
+ labels = encoding["labels"][0]
138
+
139
+ return {"pixel_values": pixel_values, "labels": labels}
140
+
141
+
142
+ def collate_fn(batch):
143
+ """Custom collate for variable-length detection labels."""
144
+ pixel_values = torch.stack([item["pixel_values"] for item in batch])
145
+ labels = [item["labels"] for item in batch]
146
+ return {"pixel_values": pixel_values, "labels": labels}
147
+
148
+
149
+ # ---------------------------------------------------------------------------
150
+ # Main
151
+ # ---------------------------------------------------------------------------
152
+
153
+ def main():
154
+ parser = argparse.ArgumentParser()
155
+ parser.add_argument("--epochs", type=int, default=30, help="Training epochs")
156
+ parser.add_argument("--batch_size", type=int, default=8, help="Batch size per device")
157
+ parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
158
+ parser.add_argument("--image_size", type=int, default=640, help="Input image size")
159
+ parser.add_argument("--max_train", type=int, default=None, help="Max train samples")
160
+ parser.add_argument("--max_val", type=int, default=None, help="Max val samples")
161
+ parser.add_argument("--output_dir", type=str, default="./rtdetr-flowchart-output")
162
+ parser.add_argument("--hub_model_id", type=str, default="rukia07/rtdetr-flowchart-detector")
163
+ parser.add_argument("--base_model", type=str, default="PekingU/rtdetr_r18vd_coco_o365")
164
+ parser.add_argument("--dataset_id", type=str, default="rukia07/flowchart-detection-dataset")
165
+ args = parser.parse_args()
166
+
167
+ print(f"{'='*60}")
168
+ print(f"RT-DETR Flowchart Detection Training")
169
+ print(f"{'='*60}")
170
+ print(f"Base model: {args.base_model}")
171
+ print(f"Dataset: {args.dataset_id}")
172
+ print(f"Image size: {args.image_size}x{args.image_size}")
173
+ print(f"Batch size: {args.batch_size}")
174
+ print(f"Epochs: {args.epochs}")
175
+ print(f"LR: {args.lr}")
176
+ print(f"Output: {args.hub_model_id}")
177
+ print(f"{'='*60}")
178
+
179
+ # Download dataset from Hub
180
+ print("\nDownloading dataset...")
181
+ dataset_dir = snapshot_download(
182
+ repo_id=args.dataset_id, repo_type="dataset",
183
+ local_dir="./flowchart_dataset"
184
+ )
185
+
186
+ # Load processor and model
187
+ print("Loading model...")
188
+ processor = AutoImageProcessor.from_pretrained(
189
+ args.base_model,
190
+ size={"height": args.image_size, "width": args.image_size},
191
+ )
192
+
193
+ model = AutoModelForObjectDetection.from_pretrained(
194
+ args.base_model,
195
+ num_labels=1,
196
+ id2label={0: "flowchart"},
197
+ label2id={"flowchart": 0},
198
+ ignore_mismatched_sizes=True,
199
+ )
200
+
201
+ print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
202
+
203
+ # Create datasets
204
+ print("Loading datasets...")
205
+ train_ds = COCODetectionDataset(
206
+ image_dir=os.path.join(dataset_dir, "train", "images"),
207
+ annotation_file=os.path.join(dataset_dir, "train", "annotations.json"),
208
+ processor=processor,
209
+ augment=True,
210
+ max_samples=args.max_train,
211
+ )
212
+
213
+ val_ds = COCODetectionDataset(
214
+ image_dir=os.path.join(dataset_dir, "val", "images"),
215
+ annotation_file=os.path.join(dataset_dir, "val", "annotations.json"),
216
+ processor=processor,
217
+ augment=False,
218
+ max_samples=args.max_val,
219
+ )
220
+
221
+ print(f"Train: {len(train_ds)} images, Val: {len(val_ds)} images")
222
+
223
+ # Training arguments
224
+ training_args = TrainingArguments(
225
+ output_dir=args.output_dir,
226
+ num_train_epochs=args.epochs,
227
+ per_device_train_batch_size=args.batch_size,
228
+ per_device_eval_batch_size=args.batch_size,
229
+ learning_rate=args.lr,
230
+ weight_decay=0.01,
231
+ lr_scheduler_type="cosine",
232
+ warmup_ratio=0.1,
233
+ max_grad_norm=0.1,
234
+ fp16=torch.cuda.is_available(),
235
+ dataloader_num_workers=4,
236
+ eval_strategy="epoch",
237
+ save_strategy="epoch",
238
+ save_total_limit=3,
239
+ load_best_model_at_end=True,
240
+ metric_for_best_model="eval_map",
241
+ greater_is_better=True,
242
+ logging_strategy="steps",
243
+ logging_steps=10,
244
+ logging_first_step=True,
245
+ disable_tqdm=True,
246
+ remove_unused_columns=False,
247
+ eval_do_concat_batches=False,
248
+ push_to_hub=True,
249
+ hub_model_id=args.hub_model_id,
250
+ report_to="none",
251
+ )
252
+
253
+ # Trainer
254
+ trainer = Trainer(
255
+ model=model,
256
+ args=training_args,
257
+ train_dataset=train_ds,
258
+ eval_dataset=val_ds,
259
+ data_collator=collate_fn,
260
+ processing_class=processor,
261
+ )
262
+
263
+ # Train
264
+ print("\nStarting training...")
265
+ train_result = trainer.train()
266
+
267
+ # Evaluate
268
+ print("\nFinal evaluation...")
269
+ metrics = trainer.evaluate()
270
+ print(f"Final metrics: {metrics}")
271
+
272
+ # Push to Hub
273
+ print(f"\nPushing to {args.hub_model_id}...")
274
+ trainer.push_to_hub(commit_message="Training complete")
275
+
276
+ print(f"\n{'='*60}")
277
+ print(f"Training complete!")
278
+ print(f"Model: https://huggingface.co/{args.hub_model_id}")
279
+ print(f"{'='*60}")
280
+
281
+
282
+ if __name__ == "__main__":
283
+ main()