# Commit aa8831b (verified) by rukia07: "Add GPU training script"
"""
RT-DETR Flowchart Detection - GPU Training Script
===================================================
Fine-tunes RT-DETR R18 for single-class flowchart bounding box detection.
Model: PekingU/rtdetr_r18vd_coco_o365 → rukia07/rtdetr-flowchart-detector
Dataset: rukia07/flowchart-detection-dataset (COCO format, 2500 train / 500 val)
Requirements:
pip install transformers torch torchvision albumentations pycocotools
pip install accelerate huggingface_hub
Usage:
# Full training (recommended: GPU with >= 8GB VRAM)
python train_gpu.py
# Quick test
python train_gpu.py --epochs 1 --max_train 100 --max_val 20
Architecture: RT-DETR (Real-Time DEtection TRansformer)
- ResNet-18 backbone → HybridEncoder → TransformerDecoder
- NMS-free, end-to-end detection
- 20M params, 217 FPS on T4 GPU
- Single class: "flowchart" (class 0)
"""
import argparse
import json
import os
import torch
import numpy as np
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset
from transformers import (
AutoModelForObjectDetection,
AutoImageProcessor,
TrainingArguments,
Trainer,
)
from huggingface_hub import hf_hub_download, snapshot_download
import albumentations as A
# ---------------------------------------------------------------------------
# Dataset
# ---------------------------------------------------------------------------
class COCODetectionDataset(Dataset):
    """COCO-format detection dataset for flowchart detection.

    Each item is a dict with ``pixel_values`` (the processor's image tensor
    with the leading batch dimension squeezed out) and ``labels`` (the
    processor's target dict for that image).

    Args:
        image_dir: Directory holding the files named by ``images[*].file_name``.
        annotation_file: Path to a COCO-style JSON file with an ``images`` list
            and optional ``annotations`` and ``categories`` lists.
        processor: Image processor, called as
            ``processor(images=..., annotations=..., return_tensors="pt")``.
        augment: If True, apply random horizontal flip and photometric
            augmentation (boxes are transformed together with the image).
        max_samples: If not None, keep only the first ``max_samples`` images
            in annotation-file order (0 yields an empty dataset).
    """

    def __init__(self, image_dir, annotation_file, processor, augment=False, max_samples=None):
        self.image_dir = Path(image_dir)
        self.processor = processor
        self.augment = augment

        with open(annotation_file, encoding="utf-8") as f:
            coco = json.load(f)

        self.images = {img["id"]: img for img in coco["images"]}

        # image_id -> list of raw COCO annotation dicts for that image.
        self.img_annots = {}
        for ann in coco.get("annotations", []):
            self.img_annots.setdefault(ann["image_id"], []).append(ann)

        # Map dataset category ids onto contiguous 0-based label ids. COCO
        # files frequently number categories from 1, while the training
        # script builds the model with id2label={0: "flowchart"}. When the
        # file already uses 0..K-1 this mapping is the identity. Ids missing
        # from "categories" are passed through unchanged in __getitem__.
        category_ids = sorted(cat["id"] for cat in coco.get("categories", []))
        self.cat_id_to_label = {cat_id: label for label, cat_id in enumerate(category_ids)}

        self.image_ids = list(self.images)
        # Compare against None so that max_samples=0 is honoured, not ignored.
        if max_samples is not None:
            self.image_ids = self.image_ids[:max_samples]

        # Augmentation pipeline (boxes are given in COCO [x, y, w, h] pixels).
        if augment:
            self.transform = A.Compose([
                A.HorizontalFlip(p=0.5),
                A.RandomBrightnessContrast(p=0.3),
                A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05, p=0.3),
                A.GaussNoise(p=0.2),
            ], bbox_params=A.BboxParams(
                format="coco", label_fields=["category_ids"], min_visibility=0.3
            ))
        else:
            self.transform = None

    @staticmethod
    def _clip_bbox(bbox, width, height):
        """Clip a COCO ``[x, y, w, h]`` box to a ``width`` x ``height`` image.

        Returns the clipped box as a list of floats, or None when the box has
        zero area after clipping (degenerate or entirely outside the image).
        Albumentations validates that boxes lie inside the image, so boxes
        must be sanitised before augmentation.
        """
        x, y, bw, bh = (float(v) for v in bbox)
        x0 = min(max(x, 0.0), width)
        y0 = min(max(y, 0.0), height)
        x1 = min(max(x + bw, 0.0), width)
        y1 = min(max(y + bh, 0.0), height)
        if x1 <= x0 or y1 <= y0:
            return None
        return [x0, y0, x1 - x0, y1 - y0]

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        img_id = self.image_ids[idx]
        img_info = self.images[img_id]

        # Decode fully inside the context manager so the file handle is
        # released immediately (several DataLoader workers open files).
        with Image.open(self.image_dir / img_info["file_name"]) as img:
            image = img.convert("RGB")
        width, height = image.size

        bboxes, categories = [], []
        for ann in self.img_annots.get(img_id, []):
            bbox = self._clip_bbox(ann["bbox"], width, height)
            if bbox is None:
                continue
            bboxes.append(bbox)
            cat_id = ann["category_id"]
            categories.append(self.cat_id_to_label.get(cat_id, cat_id))

        # Augment every training image, including ones without boxes, so
        # negatives also see photometric variation.
        if self.transform is not None:
            transformed = self.transform(
                image=np.array(image), bboxes=bboxes, category_ids=categories
            )
            image = Image.fromarray(transformed["image"])
            # Depending on the albumentations version, boxes come back as
            # tuples or array rows; normalise to plain float lists.
            bboxes = [[float(v) for v in box] for box in transformed["bboxes"]]
            categories = list(transformed["category_ids"])

        # Targets are passed in COCO [x, y, w, h] pixel format; any conversion
        # to the model's box format is left to the processor.
        targets = {
            "image_id": img_id,
            "annotations": [
                {
                    "bbox": [x, y, bw, bh],
                    "category_id": cat,
                    "area": bw * bh,
                    "iscrowd": 0,
                }
                for (x, y, bw, bh), cat in zip(bboxes, categories)
            ],
        }

        encoding = self.processor(
            images=image,
            annotations=targets,
            return_tensors="pt",
        )
        pixel_values = encoding["pixel_values"].squeeze(0)
        labels = encoding["labels"][0]
        return {"pixel_values": pixel_values, "labels": labels}
def collate_fn(batch):
    """Batch dataset items for the Trainer.

    Image tensors share a shape and are stacked into one tensor. The per-image
    label dicts hold a variable number of boxes, so they are kept as a plain
    list in the same order as the images.
    """
    images, targets = [], []
    for sample in batch:
        images.append(sample["pixel_values"])
        targets.append(sample["labels"])
    return {"pixel_values": torch.stack(images), "labels": targets}
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--epochs", type=int, default=30, help="Training epochs")
parser.add_argument("--batch_size", type=int, default=8, help="Batch size per device")
parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
parser.add_argument("--image_size", type=int, default=640, help="Input image size")
parser.add_argument("--max_train", type=int, default=None, help="Max train samples")
parser.add_argument("--max_val", type=int, default=None, help="Max val samples")
parser.add_argument("--output_dir", type=str, default="./rtdetr-flowchart-output")
parser.add_argument("--hub_model_id", type=str, default="rukia07/rtdetr-flowchart-detector")
parser.add_argument("--base_model", type=str, default="PekingU/rtdetr_r18vd_coco_o365")
parser.add_argument("--dataset_id", type=str, default="rukia07/flowchart-detection-dataset")
args = parser.parse_args()
print(f"{'='*60}")
print(f"RT-DETR Flowchart Detection Training")
print(f"{'='*60}")
print(f"Base model: {args.base_model}")
print(f"Dataset: {args.dataset_id}")
print(f"Image size: {args.image_size}x{args.image_size}")
print(f"Batch size: {args.batch_size}")
print(f"Epochs: {args.epochs}")
print(f"LR: {args.lr}")
print(f"Output: {args.hub_model_id}")
print(f"{'='*60}")
# Download dataset from Hub
print("\nDownloading dataset...")
dataset_dir = snapshot_download(
repo_id=args.dataset_id, repo_type="dataset",
local_dir="./flowchart_dataset"
)
# Load processor and model
print("Loading model...")
processor = AutoImageProcessor.from_pretrained(
args.base_model,
size={"height": args.image_size, "width": args.image_size},
)
model = AutoModelForObjectDetection.from_pretrained(
args.base_model,
num_labels=1,
id2label={0: "flowchart"},
label2id={"flowchart": 0},
ignore_mismatched_sizes=True,
)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
# Create datasets
print("Loading datasets...")
train_ds = COCODetectionDataset(
image_dir=os.path.join(dataset_dir, "train", "images"),
annotation_file=os.path.join(dataset_dir, "train", "annotations.json"),
processor=processor,
augment=True,
max_samples=args.max_train,
)
val_ds = COCODetectionDataset(
image_dir=os.path.join(dataset_dir, "val", "images"),
annotation_file=os.path.join(dataset_dir, "val", "annotations.json"),
processor=processor,
augment=False,
max_samples=args.max_val,
)
print(f"Train: {len(train_ds)} images, Val: {len(val_ds)} images")
# Training arguments
training_args = TrainingArguments(
output_dir=args.output_dir,
num_train_epochs=args.epochs,
per_device_train_batch_size=args.batch_size,
per_device_eval_batch_size=args.batch_size,
learning_rate=args.lr,
weight_decay=0.01,
lr_scheduler_type="cosine",
warmup_ratio=0.1,
max_grad_norm=0.1,
fp16=torch.cuda.is_available(),
dataloader_num_workers=4,
eval_strategy="epoch",
save_strategy="epoch",
save_total_limit=3,
load_best_model_at_end=True,
metric_for_best_model="eval_map",
greater_is_better=True,
logging_strategy="steps",
logging_steps=10,
logging_first_step=True,
disable_tqdm=True,
remove_unused_columns=False,
eval_do_concat_batches=False,
push_to_hub=True,
hub_model_id=args.hub_model_id,
report_to="none",
)
# Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_ds,
eval_dataset=val_ds,
data_collator=collate_fn,
processing_class=processor,
)
# Train
print("\nStarting training...")
train_result = trainer.train()
# Evaluate
print("\nFinal evaluation...")
metrics = trainer.evaluate()
print(f"Final metrics: {metrics}")
# Push to Hub
print(f"\nPushing to {args.hub_model_id}...")
trainer.push_to_hub(commit_message="Training complete")
print(f"\n{'='*60}")
print(f"Training complete!")
print(f"Model: https://huggingface.co/{args.hub_model_id}")
print(f"{'='*60}")
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()