Spaces:
Runtime error
Runtime error
File size: 8,401 Bytes
30672d3 bdb70cc 30672d3 bdb70cc 30672d3 bdb70cc 30672d3 bdb70cc 30672d3 bdb70cc 30672d3 bdb70cc 30672d3 bdb70cc 30672d3 bdb70cc 30672d3 bdb70cc 30672d3 bdb70cc 30672d3 bdb70cc 30672d3 bdb70cc 30672d3 bdb70cc 30672d3 bdb70cc 30672d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 |
import torch
import pandas as pd
from pathlib import Path
from tqdm import tqdm
from functools import partial
from datasets import Dataset, Image, ClassLabel
from imblearn.over_sampling import RandomOverSampler
from transformers import (
EfficientFormerImageProcessor,
EfficientFormerForImageClassification,
TrainingArguments,
Trainer
)
from torchvision.transforms import (
Compose,
Normalize,
RandomRotation,
RandomHorizontalFlip,
Resize,
ToTensor
)
import evaluate
from cnnClassifier.entity.config_entity import ModelTrainerConfig
from cnnClassifier import logger
# ==============================================================================
# TOP-LEVEL FUNCTION DEFINITIONS (FOR PICKLING)
# ==============================================================================
def apply_transforms(batch, processor, transform_pipeline):
    """Fill ``batch["pixel_values"]`` with transformed, normalized images.

    The steps of ``transform_pipeline`` run first, followed by a ``Normalize``
    step built from ``processor.image_mean`` and ``processor.image_std``.
    Every image in ``batch["image"]`` is converted to RGB before the pipeline
    is applied. The batch is modified in place and also returned.
    """
    # Normalization statistics come from the processor, not from the pipeline.
    normalization = Normalize(mean=processor.image_mean, std=processor.image_std)

    # Base steps first, normalization last.
    steps = list(transform_pipeline.transforms)
    steps.append(normalization)
    pipeline = Compose(steps)

    pixel_values = []
    for image in batch["image"]:
        pixel_values.append(pipeline(image.convert("RGB")))
    batch["pixel_values"] = pixel_values
    return batch
def collate_fn(batch):
    """Merge a list of examples into a single batch dictionary.

    Each example's ``pixel_values`` tensor is stacked along a new leading
    dimension, and each example's ``label`` is gathered into one ``labels``
    tensor.
    """
    pixel_values = torch.stack([example['pixel_values'] for example in batch])
    labels = torch.tensor([example['label'] for example in batch])
    return dict(pixel_values=pixel_values, labels=labels)
# Accuracy metric, loaded on first use and reused afterwards so that each
# evaluation pass does not call ``evaluate.load`` again.
_ACCURACY_METRIC = None


def compute_metrics(eval_pred):
    """Compute the accuracy metric for one evaluation pass.

    Args:
        eval_pred: A pair ``(predictions, label_ids)``. ``predictions`` must
            support ``argmax(axis=1)``; the index of the largest score in each
            row is taken as the predicted class.

    Returns:
        Whatever the ``accuracy`` metric's ``compute`` returns for the
        predicted classes against ``label_ids``.
    """
    global _ACCURACY_METRIC
    if _ACCURACY_METRIC is None:
        _ACCURACY_METRIC = evaluate.load("accuracy")

    predictions, label_ids = eval_pred
    predicted_labels = predictions.argmax(axis=1)
    return _ACCURACY_METRIC.compute(predictions=predicted_labels, references=label_ids)
# ==============================================================================
class ModelTrainer:
    """Fine-tune an EfficientFormer image classifier on age-labelled folders.

    The trainer reads these attributes from ``config``: ``data_path``,
    ``random_state``, ``test_split_size``, ``model_name``, ``image_size``,
    ``root_dir``, ``learning_rate``, ``batch_size``, ``num_train_epochs``,
    ``weight_decay``, ``warmup_steps`` and ``trained_model_path``.
    """

    # Maps the name of an image's parent folder to its age-range class.
    # Folder names missing from this table are rejected by ``_prepare_data``.
    _LABEL_DICT = {
        '001': '01', '002': '02', '003': '03', '004': '04', '005': '05',
        '006': '06-07', '007': '06-07', '008': '08-09', '009': '08-09',
        '010': '10-12', '011': '10-12', '012': '10-12', '013': '13-15',
        '014': '13-15', '015': '13-15', '016': '16-20', '017': '16-20',
        '018': '16-20', '019': '16-20', '020': '16-20', '021': '21-25',
        '022': '21-25', '023': '21-25', '024': '21-25', '025': '21-25',
        '026': '26-30', '027': '26-30', '028': '26-30', '029': '26-30',
        '030': '26-30', '031': '31-35', '032': '31-35', '033': '31-35',
        '034': '31-35', '035': '31-35', '036': '36-40', '037': '36-40',
        '038': '36-40', '039': '36-40', '040': '36-40', '041': '41-45',
        '042': '41-45', '043': '41-45', '044': '41-45', '045': '41-45',
        '046': '46-50', '047': '46-50', '048': '46-50', '049': '46-50',
        '050': '46-50', '051': '51-55', '052': '51-55', '053': '51-55',
        '054': '51-55', '055': '51-55', '056': '56-60', '057': '56-60',
        '058': '56-60', '059': '56-60', '060': '56-60', '061': '61-65',
        '062': '61-65', '063': '61-65', '064': '61-65', '065': '61-65',
        '066': '66-70', '067': '66-70', '068': '66-70', '069': '66-70',
        '070': '66-70', '071': '71-80', '072': '71-80', '073': '71-80',
        '074': '71-80', '075': '71-80', '076': '71-80', '077': '71-80',
        '078': '71-80', '079': '71-80', '080': '71-80', '081': '81-90',
        '082': '81-90', '083': '81-90', '084': '81-90', '085': '81-90',
        '086': '81-90', '087': '81-90', '088': '81-90', '089': '81-90',
        '090': '81-90', '091': '90+', '092': '90+', '093': '90+',
        '095': '90+', '096': '90+', '099': '90+', '100': '90+',
        '101': '90+', '110': '90+',
    }

    def __init__(self, config: ModelTrainerConfig):
        """Store the trainer configuration."""
        self.config = config

    def _prepare_data(self):
        """Build the oversampled, label-encoded train/test split.

        Collects every ``<folder>/<file>.<ext>`` path under
        ``config.data_path``, maps each folder name to an age-range class,
        oversamples so the classes are balanced, encodes labels as
        ``ClassLabel`` integers and makes a stratified train/test split.

        Returns:
            A tuple ``(split_dataset, id2label, label2id)`` where
            ``split_dataset`` has ``'train'`` and ``'test'`` entries and the
            two dicts map between class index and class name.

        Raises:
            KeyError: If a folder name has no entry in ``_LABEL_DICT``.
            ValueError: If no files are found under ``config.data_path``.
        """
        logger.info("Preparing data...")
        data_path = Path(self.config.data_path)

        file_names, labels = [], []
        for file in tqdm(sorted(data_path.glob('*/*.*'))):
            folder = file.parent.name
            label = self._LABEL_DICT.get(folder)
            if label is None:
                # Same exception type as a plain dict lookup, but with enough
                # context to locate the offending folder.
                raise KeyError(
                    f"Unknown age folder {folder!r} (file {file}); "
                    "add it to ModelTrainer._LABEL_DICT"
                )
            labels.append(label)
            file_names.append(str(file))

        if not file_names:
            # Fail here rather than with an opaque error inside the sampler.
            raise ValueError(f"No image files found under {data_path}")

        df = pd.DataFrame({"image": file_names, "label": labels})

        # NOTE(review): oversampling happens before the train/test split, so
        # copies of the same image can end up in both splits and evaluation
        # accuracy may be optimistic. Consider splitting first and
        # oversampling only the training part.
        ros = RandomOverSampler(random_state=self.config.random_state)
        df_resampled, y_resampled = ros.fit_resample(df[['image']], df['label'])
        df = pd.concat([df_resampled, y_resampled], axis=1)

        dataset = Dataset.from_pandas(df).cast_column("image", Image())

        # Sorted class names give a stable class-index assignment.
        labels_list = sorted(set(labels))
        label2id = {label: i for i, label in enumerate(labels_list)}
        id2label = {i: label for i, label in enumerate(labels_list)}

        class_labels = ClassLabel(num_classes=len(labels_list), names=labels_list)
        dataset = dataset.map(
            lambda x: {'label': class_labels.str2int(x['label'])},
            batched=True,
        )
        dataset = dataset.cast_column('label', class_labels)

        # Seed the split with the same random_state used for oversampling so
        # the whole data preparation is reproducible.
        split_dataset = dataset.train_test_split(
            test_size=self.config.test_split_size,
            shuffle=True,
            stratify_by_column="label",
            seed=self.config.random_state,
        )
        return split_dataset, id2label, label2id

    def train(self):
        """Fine-tune the model and save the best checkpoint.

        Prepares the data, attaches on-the-fly image transforms to both
        splits, loads the pretrained model named by ``config.model_name`` with
        a head sized for the dataset's classes, trains with per-epoch
        evaluation and checkpointing, and saves the resulting model to
        ``config.trained_model_path``.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {device}")

        split_dataset, id2label, label2id = self._prepare_data()
        train_data = split_dataset['train']
        test_data = split_dataset['test']

        processor = EfficientFormerImageProcessor.from_pretrained(self.config.model_name)

        # Base transforms only; normalization is appended inside
        # apply_transforms using the processor's statistics.
        image_size = (self.config.image_size, self.config.image_size)
        _train_transforms = Compose([
            Resize(image_size),
            RandomRotation(15),
            RandomHorizontalFlip(0.5),
            ToTensor(),
        ])
        _val_transforms = Compose([
            Resize(image_size),
            ToTensor(),
        ])

        # functools.partial over a top-level function stays picklable, which
        # matters because the data loader uses worker processes.
        train_transform_func = partial(
            apply_transforms, processor=processor, transform_pipeline=_train_transforms
        )
        val_transform_func = partial(
            apply_transforms, processor=processor, transform_pipeline=_val_transforms
        )
        train_data.set_transform(train_transform_func)
        test_data.set_transform(val_transform_func)

        model = EfficientFormerForImageClassification.from_pretrained(
            self.config.model_name,
            num_labels=len(id2label),
            id2label=id2label,
            label2id=label2id,
            # The pretrained head has a different number of classes.
            ignore_mismatched_sizes=True,
        ).to(device)

        args = TrainingArguments(
            output_dir=self.config.root_dir,
            logging_dir=f'{self.config.root_dir}/logs',
            evaluation_strategy="epoch",
            learning_rate=self.config.learning_rate,
            per_device_train_batch_size=self.config.batch_size,
            per_device_eval_batch_size=self.config.batch_size,
            num_train_epochs=self.config.num_train_epochs,
            weight_decay=self.config.weight_decay,
            warmup_steps=self.config.warmup_steps,
            save_strategy='epoch',
            load_best_model_at_end=True,
            metric_for_best_model="accuracy",
            save_total_limit=1,
            # Keep the raw `image` column: the on-the-fly transform reads it.
            remove_unused_columns=False,
            dataloader_num_workers=4,
            report_to="none",
        )

        trainer = Trainer(
            model=model,
            args=args,
            train_dataset=train_data,
            eval_dataset=test_data,
            data_collator=collate_fn,
            compute_metrics=compute_metrics,
            tokenizer=processor,
        )
        trainer.train()

        logger.info(f"Saving best model to {self.config.trained_model_path}")
        trainer.save_model(self.config.trained_model_path)