# Document_Forgery_Detection/scripts/train_chunked.py
"""
Chunked Training Script for Document Forgery Detection
Supports training on large datasets (DocTamper) in sequential chunks to stay within RAM constraints.
Usage:
python scripts/train_chunked.py --dataset doctamper --chunk 1
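    python scripts/train_chunked.py --dataset doctamper            # all chunks sequentially
    python scripts/train_chunked.py --dataset doctamper --chunk 2 --resume <checkpoint.pth>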
python scripts/train_chunked.py --dataset rtm
python scripts/train_chunked.py --dataset casia
python scripts/train_chunked.py --dataset receipts
"""
import argparse
import sys
from pathlib import Path
from typing import Optional
# Add src to path
sys.path.insert(0, str(Path(__file__).parent.parent))
import torch
import gc
from src.config import get_config
from src.training import get_trainer
from src.utils import plot_training_curves, plot_chunked_training_progress, generate_training_report
def parse_args():
parser = argparse.ArgumentParser(description="Train forgery detection model")
parser.add_argument('--dataset', type=str, default='doctamper',
choices=['doctamper', 'rtm', 'casia', 'receipts', 'fcd', 'scd'],
help='Dataset to train on')
parser.add_argument('--chunk', type=int, default=None,
help='Chunk number (1-4) for DocTamper chunked training')
parser.add_argument('--epochs', type=int, default=None,
help='Number of epochs (overrides config)')
parser.add_argument('--resume', type=str, default=None,
help='Checkpoint to resume from')
parser.add_argument('--config', type=str, default='config.yaml',
help='Path to config file')
return parser.parse_args()
def train_chunk(config, dataset_name: str, chunk_id: int, epochs: Optional[int] = None, resume: Optional[str] = None):
"""Train a single chunk"""
# Calculate chunk boundaries
chunks = config.get('data.chunked_training.chunks', [])
    if not 1 <= chunk_id <= len(chunks):
        raise ValueError(f"Invalid chunk ID: {chunk_id}. Valid range: 1-{len(chunks)}")
chunk_config = chunks[chunk_id - 1]
chunk_start = chunk_config['start']
chunk_end = chunk_config['end']
chunk_name = chunk_config['name']
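    # Expected config layout (illustrative assumption -- the keys match the lookups
    # above, but the names and boundaries below are examples only):
    #   data:
    #     chunked_training:
    #       chunks:
    #         - {name: "chunk_1", start: 0.00, end: 0.25}
    #         - {name: "chunk_2", start: 0.25, end: 0.50}
    #         - {name: "chunk_3", start: 0.50, end: 0.75}
    #         - {name: "chunk_4", start: 0.75, end: 1.00}
    # start/end are fractions of the dataset, e.g. chunk 2 covers samples 25%-50%.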
print(f"\n{'='*60}")
print(f"Training Chunk {chunk_id}: {chunk_name}")
print(f"Range: {chunk_start*100:.0f}% - {chunk_end*100:.0f}%")
print(f"{'='*60}")
# Create trainer
trainer = get_trainer(config, dataset_name)
# Resume from previous chunk if applicable
if resume:
# For chunked training, reset epoch counter to train full epochs on new data
trainer.load_checkpoint(resume, reset_epoch=True)
elif chunk_id > 1:
        # Auto-resume from the previous chunk's final checkpoint
        # (assumes the trainer saves it as '<dataset>_chunk<N>_final.pth')
        prev_checkpoint = f'{dataset_name}_chunk{chunk_id-1}_final.pth'
        checkpoint_dir = Path(config.get('outputs.checkpoints', 'outputs/checkpoints'))
        if (checkpoint_dir / prev_checkpoint).exists():
            print(f"Auto-resuming from previous chunk: {prev_checkpoint}")
            trainer.load_checkpoint(prev_checkpoint, reset_epoch=True)
# Train
history = trainer.train(
epochs=epochs,
chunk_start=chunk_start,
chunk_end=chunk_end,
chunk_id=chunk_id,
resume_from=None # Already loaded above
)
# Plot training curves
plot_dir = Path(config.get('outputs.plots', 'outputs/plots'))
plot_dir.mkdir(parents=True, exist_ok=True)
plot_path = plot_dir / f'{dataset_name}_chunk{chunk_id}_curves.png'
plot_training_curves(
history,
str(plot_path),
title=f"{dataset_name.upper()} Chunk {chunk_id} Training"
)
# Generate report
report_path = plot_dir / f'{dataset_name}_chunk{chunk_id}_report.txt'
generate_training_report(history, str(report_path), f"{dataset_name} Chunk {chunk_id}")
# Clear memory
del trainer
gc.collect()
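    # empty_cache() is a no-op if CUDA was never initialized, so this is safe on CPU-only runs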
torch.cuda.empty_cache()
return history
def train_full_dataset(config, dataset_name: str, epochs: Optional[int] = None, resume: Optional[str] = None):
"""Train on full dataset (for smaller datasets)"""
print(f"\n{'='*60}")
print(f"Training on: {dataset_name.upper()}")
print(f"{'='*60}")
# Create trainer
trainer = get_trainer(config, dataset_name)
# Load checkpoint if resuming (reset epoch counter for new dataset)
if resume:
print(f"Loading weights from: {resume}")
trainer.load_checkpoint(resume, reset_epoch=True)
print("Epoch counter reset to 0 for new dataset training")
# Train
history = trainer.train(
epochs=epochs,
        chunk_id=0,  # 0 marks a full-dataset (non-chunked) run for the trainer
resume_from=None # Already loaded above
)
# Plot training curves
plot_dir = Path(config.get('outputs.plots', 'outputs/plots'))
plot_dir.mkdir(parents=True, exist_ok=True)
plot_path = plot_dir / f'{dataset_name}_training_curves.png'
plot_training_curves(
history,
str(plot_path),
title=f"{dataset_name.upper()} Training"
)
# Generate report
report_path = plot_dir / f'{dataset_name}_report.txt'
generate_training_report(history, str(report_path), dataset_name)
return history
def main():
args = parse_args()
# Load config
config = get_config(args.config)
print("\n" + "="*60)
print("Hybrid Document Forgery Detection - Training")
print("="*60)
print(f"Dataset: {args.dataset}")
print(f"Device: {config.get('system.device')}")
print(f"CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
print("="*60)
# DocTamper: chunked training
if args.dataset == 'doctamper' and args.chunk is not None:
history = train_chunk(
config,
args.dataset,
args.chunk,
epochs=args.epochs,
resume=args.resume
)
# DocTamper: all chunks sequentially
elif args.dataset == 'doctamper' and args.chunk is None:
print("Training DocTamper in 4 chunks...")
all_histories = []
for chunk_id in range(1, 5):
history = train_chunk(
config,
args.dataset,
chunk_id,
epochs=args.epochs,
                resume=None  # train_chunk() auto-resumes from the previous chunk's checkpoint
)
all_histories.append(history)
# Plot combined progress
plot_dir = Path(config.get('outputs.plots', 'outputs/plots'))
combined_path = plot_dir / 'doctamper_all_chunks_progress.png'
plot_chunked_training_progress(
all_histories,
str(combined_path),
title="DocTamper Chunked Training Progress"
)
# Other datasets: full training
else:
history = train_full_dataset(
config,
args.dataset,
epochs=args.epochs,
resume=args.resume
)
print("\n" + "="*60)
print("Training Complete!")
print("="*60)
if __name__ == '__main__':
main()