# DamageLensAI — src/training/train_fusion.py
# (Hugging Face repository page header converted to a comment: file uploaded
#  by junaid17 in commit eef8873, "Upload 43 files".)
import logging
import torch.nn as nn
from torch.optim import AdamW
from src.config import DEVICE, EPOCHS, NUM_CLASSES
from src.models.fusion_model import FusionClassifier
from src.data.dataset import create_fusion_dataloaders
from src.training.trainer import train_dual_input_model
# Module-level logger. Handlers and level come from whoever configures
# logging (e.g. ``logging.basicConfig`` when this file is run as a script).
logger = logging.getLogger(name=__name__)
def find_untracked_trainable_params(model, param_groups):
    """Return names of trainable parameters of *model* absent from *param_groups*.

    A parameter with ``requires_grad=True`` that is not listed in any
    optimizer param group still gets gradients computed for it, but the
    optimizer never updates it. That is almost always a configuration
    mistake (either it should be frozen or it should be in a group).

    Args:
        model: ``nn.Module`` whose ``named_parameters()`` are checked.
        param_groups: iterable of dicts in the ``torch.optim`` param-group
            format, each with a ``"params"`` iterable of tensors. The
            ``"params"`` entries are iterated here, so pass re-iterable
            containers (e.g. lists) if the same groups go to an optimizer.

    Returns:
        List of parameter names, in ``model.named_parameters()`` order.
    """
    tracked = {id(param) for group in param_groups for param in group["params"]}
    return [
        name
        for name, param in model.named_parameters()
        if param.requires_grad and id(param) not in tracked
    ]


def run_fusion_training(*, epochs=None, patience=7):
    """Train the EfficientNet + ConvNeXt fusion classifier.

    Builds the fusion dataloaders and ``FusionClassifier``, assigns
    per-block learning rates to the selected backbone blocks and the fusion
    head, and runs ``train_dual_input_model`` with label-smoothed
    cross-entropy. The best model is checkpointed under the name
    ``"best_fusion_model"``.

    Args:
        epochs: number of training epochs; ``None`` (default) uses
            ``EPOCHS`` from ``src.config``.
        patience: early-stopping patience forwarded to
            ``train_dual_input_model`` (default 7).

    Returns:
        ``(all_preds, all_labels)`` exactly as returned by
        ``train_dual_input_model``.
    """
    if epochs is None:
        # Resolved at call time so the signature does not bind the config value.
        epochs = EPOCHS

    logger.info("Initializing Fusion training pipeline...")
    train_loader, eval_loader = create_fusion_dataloaders()

    model = FusionClassifier(num_classes=NUM_CLASSES).to(DEVICE)

    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

    # Per-block learning rates: the earliest listed EfficientNet block gets a
    # smaller LR (1e-5) than the later backbone blocks (3e-5), and the fusion
    # head gets the largest (1e-4). Parameters are materialised as lists
    # (instead of the one-shot generators ``.parameters()`` returns) so they
    # can be checked below and still be handed to AdamW.
    param_groups = [
        # EfficientNet unfrozen blocks
        {"params": list(model.eff_features[5].parameters()), "lr": 1e-5},
        {"params": list(model.eff_features[6].parameters()), "lr": 3e-5},
        {"params": list(model.eff_features[7].parameters()), "lr": 3e-5},
        # ConvNeXt unfrozen blocks
        {
            "params": list(model.cnx_backbone.encoder.stages[2].parameters()),
            "lr": 3e-5,
        },
        {
            "params": list(model.cnx_backbone.encoder.stages[3].parameters()),
            "lr": 3e-5,
        },
        {
            "params": list(model.cnx_backbone.layernorm.parameters()),
            "lr": 3e-5,
        },
        # Fusion head
        {"params": list(model.fusion_head.parameters()), "lr": 1e-4},
    ]

    # Trainable parameters outside every group would silently never be
    # updated; surface that instead of training a partially-stale model.
    untracked = find_untracked_trainable_params(model, param_groups)
    if untracked:
        logger.warning(
            "%d trainable parameter tensor(s) are not in any optimizer group "
            "and will never be updated (first few: %s). Freeze them or add "
            "them to a param group.",
            len(untracked),
            ", ".join(untracked[:10]),
        )

    optimizer = AdamW(param_groups, weight_decay=1e-4)

    logger.info("Starting Fusion training...")
    all_preds, all_labels = train_dual_input_model(
        model=model,
        train_loader=train_loader,
        eval_loader=eval_loader,
        optimizer=optimizer,
        criterion=criterion,
        device=DEVICE,
        epochs=epochs,
        checkpoint_model_name="best_fusion_model",
        patience=patience,
    )
    logger.info("Fusion training completed.")
    return all_preds, all_labels
if __name__ == "__main__":
    # Logging is configured only when run as a script, so importing this
    # module never reconfigures the caller's root logger.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(message)s",
        level=logging.INFO,
    )

    predictions, targets = run_fusion_training()

    print("\nFusion training completed successfully.")
    # Show the first ten entries of each returned sequence as a sanity check.
    for caption, values in (
        ("Prediction samples:", predictions),
        ("Label samples:", targets),
    ):
        print(caption, values[:10])