import torch import torch.nn as nn import pandas as pd import numpy as np from torch.utils.data import Dataset from transformers import ( AutoImageProcessor, AutoModelForImageClassification, TrainingArguments, Trainer ) from torchvision.transforms import ( Compose, Normalize, RandomRotation, RandomHorizontalFlip, Resize, ToTensor ) from cnnClassifier.entity.config_entity import MultiTaskModelTrainerConfig from cnnClassifier import logger from PIL import Image from sklearn.model_selection import train_test_split from sklearn.metrics import accuracy_score class MultiTaskEfficientNet(nn.Module): def __init__(self, model_name, num_labels_age, num_labels_gender, num_labels_race): super().__init__() self.efficientnet_base = AutoModelForImageClassification.from_pretrained(model_name, ignore_mismatched_sizes=True) original_classifier = self.efficientnet_base.classifier feature_dim = original_classifier.in_features self.efficientnet_base.classifier = nn.Identity() self.age_classifier = nn.Linear(feature_dim, num_labels_age) self.gender_classifier = nn.Linear(feature_dim, num_labels_gender) self.race_classifier = nn.Linear(feature_dim, num_labels_race) def forward(self, pixel_values, labels=None): features = self.efficientnet_base.efficientnet(pixel_values) pooled_features = features.last_hidden_state.mean(dim=[2, 3]) age_logits = self.age_classifier(pooled_features) gender_logits = self.gender_classifier(pooled_features) race_logits = self.race_classifier(pooled_features) loss = None if labels is not None: loss_fct = nn.CrossEntropyLoss() age_loss = loss_fct(age_logits, labels[:, 0]) gender_loss = loss_fct(gender_logits, labels[:, 1]) race_loss = loss_fct(race_logits, labels[:, 2]) loss = (2.0 * age_loss) + gender_loss + race_loss return {"loss": loss, "age_logits": age_logits, "gender_logits": gender_logits, "race_logits": race_logits} class FairFaceDataset(Dataset): def __init__(self, dataframe, processor, transforms): self.dataframe = dataframe self.processor = processor self.transforms = transforms self.normalize = Normalize(mean=processor.image_mean, std=processor.image_std) def __len__(self): return len(self.dataframe) def __getitem__(self, idx): row = self.dataframe.iloc[idx] image_path = row['image_file_path'] image = Image.open(image_path).convert("RGB") pixel_values = self.transforms(image) pixel_values = self.normalize(pixel_values) labels = torch.tensor([row['age_id'], row['gender_id'], row['race_id']], dtype=torch.long) return {"pixel_values": pixel_values, "labels": labels} def compute_multitask_metrics(eval_pred): predictions, labels = eval_pred age_logits, gender_logits, race_logits = predictions['age_logits'], predictions['gender_logits'], predictions['race_logits'] age_preds = np.argmax(age_logits, axis=1) gender_preds = np.argmax(gender_logits, axis=1) race_preds = np.argmax(race_logits, axis=1) age_labels, gender_labels, race_labels = labels[:, 0], labels[:, 1], labels[:, 2] age_acc = accuracy_score(age_labels, age_preds) gender_acc = accuracy_score(gender_labels, gender_preds) race_acc = accuracy_score(race_labels, race_preds) overall_acc = (age_acc + gender_acc + race_acc) / 3.0 return {"age_accuracy": age_acc, "gender_accuracy": gender_acc, "race_accuracy": race_acc, "overall_accuracy": overall_acc} class MultiTaskModelTrainer: def __init__(self, config: MultiTaskModelTrainerConfig): self.config = config self.processor = AutoImageProcessor.from_pretrained(config.model_name) def train(self): device = "cuda" if torch.cuda.is_available() else "cpu" logger.info(f"Using device: {device}") logger.info("Loading and preparing dataset from cleaned CSV...") df = pd.read_csv(self.config.data_path) label_maps = {} for task in ['age', 'gender', 'race']: labels = sorted(df[task].unique()) label_maps[f'{task}_label2id'] = {label: i for i, label in enumerate(labels)} df[f'{task}_id'] = df[task].map(label_maps[f'{task}_label2id']) num_classes_age = len(label_maps['age_label2id']) num_classes_gender = len(label_maps['gender_label2id']) num_classes_race = len(label_maps['race_label2id']) train_df, test_df = train_test_split(df, test_size=self.config.test_split_size, random_state=self.config.random_state, stratify=df['age']) train_transforms = Compose([ Resize((self.config.image_size, self.config.image_size)), RandomHorizontalFlip(), RandomRotation(10), ToTensor(), # Normalization is now in the Dataset ]) val_transforms = Compose([ Resize((self.config.image_size, self.config.image_size)), ToTensor(), ]) train_dataset = FairFaceDataset(dataframe=train_df, processor=self.processor, transforms=train_transforms) test_dataset = FairFaceDataset(dataframe=test_df, processor=self.processor, transforms=val_transforms) model = MultiTaskEfficientNet(model_name=self.config.model_name, num_labels_age=num_classes_age, num_labels_gender=num_classes_gender, num_labels_race=num_classes_race).to(device) args = TrainingArguments( output_dir=self.config.root_dir, logging_dir=f'{self.config.root_dir}/logs', evaluation_strategy="epoch", learning_rate=self.config.learning_rate, per_device_train_batch_size=self.config.batch_size, per_device_eval_batch_size=self.config.batch_size, num_train_epochs=self.config.num_train_epochs, weight_decay=self.config.weight_decay, save_strategy='epoch', load_best_model_at_end=True, metric_for_best_model="eval_overall_accuracy", dataloader_num_workers=4, lr_scheduler_type='cosine', report_to="none" ) class EvalTrainer(Trainer): def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None): has_labels = "labels" in inputs inputs = self._prepare_inputs(inputs) with torch.no_grad(): outputs = model(**inputs) loss = outputs.get("loss") predictions = {"age_logits": outputs["age_logits"], "gender_logits": outputs["gender_logits"], "race_logits": outputs["race_logits"]} return (loss, predictions, inputs["labels"] if has_labels else None) trainer = EvalTrainer(model=model, args=args, train_dataset=train_dataset, eval_dataset=test_dataset, compute_metrics=compute_multitask_metrics) trainer.train() logger.info(f"Saving final model and processor to {self.config.trained_model_path}") trainer.save_model(self.config.trained_model_path) self.processor.save_pretrained(self.config.trained_model_path)